* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#ifndef __MSOPPROF_NUMBER_OPERATION_H__
#define __MSOPPROF_NUMBER_OPERATION_H__
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
#include "log.h"
namespace Utility {
/**
 * Floating-point overload: reports whether num lies strictly inside
 * (-epsilon, +epsilon), where epsilon is std::numeric_limits<T>::epsilon().
 * NaN is never considered zero because both comparisons fail for it.
 */
template<typename T>
inline typename std::enable_if<std::is_floating_point<T>::value, bool>::type IsZero(T num)
{
    const T tolerance = std::numeric_limits<T>::epsilon();
    return (-tolerance < num) && (num < tolerance);
}
/**
 * Integral overload: an integer is zero only when it compares equal to 0.
 */
template<typename T>
inline typename std::enable_if<std::is_integral<T>::value, bool>::type IsZero(T num)
{
    const bool isZero = (static_cast<T>(0) == num);
    return isZero;
}
/**
 * Equality for non-floating-point types.
 *
 * @param a         First operand.
 * @param b         Second operand.
 * @param precision Accepted only so the signature mirrors the floating-point
 *                  overload; it takes no part in the comparison.
 * @return Result of a == b.
 */
template<typename T>
bool SafeEqual(T a, T b, T precision = 0,
               typename std::enable_if<!std::is_floating_point<T>::value, void>::type* = nullptr)
{
    static_cast<void>(precision);
    const bool identical = (a == b);
    return identical;
}
/**
 * Approximate equality for floating-point types.
 *
 * @param a         First operand.
 * @param b         Second operand.
 * @param precision Largest absolute difference still treated as equal
 *                  (defaults to 0.0001).
 * @return true when |a - b| <= precision. A NaN difference (for example
 *         from a NaN operand or from subtracting equal infinities) compares false.
 */
template<typename T>
bool SafeEqual(T a, T b, T precision = 0.0001,
               typename std::enable_if<std::is_floating_point<T>::value, void>::type* = nullptr)
{
    const T gap = std::fabs(a - b);
    return gap <= precision;
}
/**
 * Picks the value reported after an overflow.
 *
 * @param num         Only its sign bit is inspected.
 * @param returnValue When false the result is always 0.
 * @return 0 if returnValue is false; otherwise the lowest value of T when the
 *         sign bit of num is set, and the largest value of T when it is not.
 */
template<typename T>
T OverFlowReturnValue(T num, bool returnValue)
{
    if (returnValue) {
        return std::signbit(num) ? std::numeric_limits<T>::lowest() : std::numeric_limits<T>::max();
    }
    return 0;
}
/**
 * Adds two numbers, checking for overflow before the addition is performed.
 *
 * @param num1        Left operand.
 * @param num2        Right operand.
 * @param location    Caller-supplied tag written to the debug log on overflow.
 * @param returnValue Forwarded to OverFlowReturnValue: when false an overflow yields 0.
 * @return num1 + num2 when representable. Otherwise the value chosen by
 *         OverFlowReturnValue: the largest T for an upward overflow, the lowest T
 *         for a signed downward overflow.
 */
template<typename T>
T SafeAdd(T num1, T num2, const std::string &location, bool returnValue = true)
{
    const T upperBound = std::numeric_limits<T>::max();
    const T lowerBound = std::numeric_limits<T>::lowest();

    if (!std::is_signed<T>::value) {
        // Unsigned operands can only wrap past the maximum.
        if (num1 > upperBound - num2) {
            LogDebug("Unsigned addition overflow at %s, num1=%lu, num2=%lu", location.c_str(),
                     static_cast<unsigned long>(num1), static_cast<unsigned long>(num2));
            return OverFlowReturnValue(upperBound, returnValue);
        }
        return num1 + num2;
    }

    // The sign tests come first so the bound arithmetic below cannot overflow itself.
    const bool bothPositive = (num1 > 0) && (num2 > 0);
    if (bothPositive && num1 > upperBound - num2) {
        LogDebug("Signed addition overflow at %s, num1=%ld, num2=%ld", location.c_str(),
                 static_cast<long>(num1), static_cast<long>(num2));
        return OverFlowReturnValue(upperBound, returnValue);
    }

    const bool bothNegative = (num1 < 0) && (num2 < 0);
    if (bothNegative && num1 < lowerBound - num2) {
        LogDebug("Signed addition underflow at %s, num1=%ld, num2=%ld", location.c_str(),
                 static_cast<long>(num1), static_cast<long>(num2));
        return OverFlowReturnValue(lowerBound, returnValue);
    }

    return num1 + num2;
}
/**
 * Subtracts num2 from num1, checking for overflow before the subtraction is performed.
 *
 * @param num1        Minuend.
 * @param num2        Subtrahend.
 * @param location    Caller-supplied tag written to the debug log on overflow.
 * @param returnValue Forwarded to OverFlowReturnValue: when false an overflow yields 0.
 * @return num1 - num2 when representable. Otherwise the value chosen by
 *         OverFlowReturnValue. Note that an unsigned result below zero is reported
 *         with the largest T, not with 0.
 */
template<typename T>
T SafeSub(T num1, T num2, const std::string &location, bool returnValue = true)
{
    const T lowerBound = std::numeric_limits<T>::lowest();
    const T upperBound = std::numeric_limits<T>::max();

    if (!std::is_signed<T>::value) {
        // Unsigned: the difference is representable only when num1 is not smaller than num2.
        if (num1 < num2) {
            LogDebug("Unsigned subtraction overflow at %s, num1=%lu, num2=%lu", location.c_str(),
                     static_cast<unsigned long>(num1), static_cast<unsigned long>(num2));
            return OverFlowReturnValue(upperBound, returnValue);
        }
        return num1 - num2;
    }

    // Subtracting a negative from a non-negative value may exceed the maximum.
    const bool mayExceedMax = (num1 >= 0) && (num2 < 0);
    if (mayExceedMax && num1 > upperBound + num2) {
        LogDebug("Signed subtraction overflow at %s, num1=%ld, num2=%ld", location.c_str(),
                 static_cast<long>(num1), static_cast<long>(num2));
        return OverFlowReturnValue(upperBound, returnValue);
    }

    // Subtracting a positive from a negative value may drop below the minimum.
    const bool mayDropBelowMin = (num1 < 0) && (num2 > 0);
    if (mayDropBelowMin && num1 < lowerBound + num2) {
        LogDebug("Signed subtraction underflow at %s, num1=%ld, num2=%ld", location.c_str(),
                 static_cast<long>(num1), static_cast<long>(num2));
        return OverFlowReturnValue(lowerBound, returnValue);
    }

    return num1 - num2;
}
/**
 * Compares the sign bits of two values as reported by std::signbit.
 *
 * @return true when both sign bits are set or both are clear.
 */
template<typename T>
bool HaveSameSign(T num1, T num2)
{
    const bool firstNegative = std::signbit(num1);
    const bool secondNegative = std::signbit(num2);
    return firstNegative == secondNegative;
}
/**
 * Reports whether num equals one of the values the Safe* helpers return on overflow.
 *
 * @return true when SafeEqual matches num against the largest T, or, for signed
 *         T only, against the lowest T.
 */
template<typename T>
bool CheckOverFlow(T num)
{
    if (SafeEqual(num, std::numeric_limits<T>::max())) {
        return true;
    }
    // The lowest value of an unsigned type is 0, which is an ordinary result, so it is skipped.
    return std::is_signed<T>::value && SafeEqual(num, std::numeric_limits<T>::lowest());
}
/**
 * Floating-point overflow test for SafeMul.
 * The product is stored in T and checked for being finite; a floating-point
 * overflow produces infinity, so forming the product is safe.
 */
template<typename T>
bool SafeMulWouldOverflow(T num1, T num2, std::true_type)
{
    const T product = num1 * num2;
    return !std::isfinite(product);
}

/**
 * Integer overflow test for SafeMul.
 * The decision is made with divisions only, so the multiplication itself is
 * never evaluated when it would overflow, and no division can trap (the
 * minimum value is never divided by -1).
 */
template<typename T>
bool SafeMulWouldOverflow(T num1, T num2, std::false_type)
{
    if (num1 == 0 || num2 == 0) {
        return false;
    }
    const T maxValue = std::numeric_limits<T>::max();
    const T minValue = std::numeric_limits<T>::lowest();
    if (num1 > 0) {
        return (num2 > 0) ? (num1 > maxValue / num2) : (num2 < minValue / num1);
    }
    return (num2 > 0) ? (num1 < minValue / num2) : (num2 < maxValue / num1);
}

/**
 * Multiplies two numbers with overflow detection.
 *
 * @param num1        Left operand.
 * @param num2        Right operand.
 * @param location    Caller-supplied tag written to the debug log on overflow.
 * @param returnValue Forwarded to OverFlowReturnValue: when false an overflow yields 0.
 * @return 0 when SafeEqual treats either operand as zero (for floating-point T
 *         this uses SafeEqual's default tolerance, so operands within that
 *         tolerance of zero also yield 0). Otherwise num1 * num2 when it is
 *         representable. On overflow, the value chosen by OverFlowReturnValue:
 *         the largest T when the operands have the same sign, the lowest T when
 *         they differ.
 */
template<typename T>
T SafeMul(T num1, T num2, const std::string &location, bool returnValue = true)
{
    if (SafeEqual(num1, T(0)) || SafeEqual(num2, T(0))) {
        return 0;
    }
    // Decide before multiplying: signed integer overflow is undefined behaviour.
    if (!SafeMulWouldOverflow(num1, num2, std::is_floating_point<T>())) {
        return num1 * num2;
    }
    if (std::is_floating_point<T>::value) {
        LogDebug("Float multiplication overflow at %s, num1=%f, num2=%f", location.c_str(),
                 static_cast<double>(num1), static_cast<double>(num2));
    } else {
        LogDebug("Integer multiplication overflow at %s, num1=%ld, num2=%ld", location.c_str(),
                 static_cast<long>(num1), static_cast<long>(num2));
    }
    const T saturated = HaveSameSign(num1, num2) ? std::numeric_limits<T>::max()
                                                 : std::numeric_limits<T>::lowest();
    return OverFlowReturnValue(saturated, returnValue);
}
/**
 * Sums every element of nums with SafeAdd, starting from 0.
 *
 * @param nums        Values to add.
 * @param location    Caller-supplied tag passed through to SafeAdd.
 * @param returnValue Passed through to SafeAdd.
 * @return The running sum. Accumulation stops at the first point where
 *         CheckOverFlow reports the running sum as an overflow value, and that
 *         value is returned.
 */
template<typename T>
T SafeAddAll(const std::vector<T> &nums, const std::string &location, bool returnValue = true)
{
    T total = 0;
    for (auto it = nums.begin(); it != nums.end(); ++it) {
        total = SafeAdd(total, *it, location, returnValue);
        if (CheckOverFlow(total)) {
            break;
        }
    }
    return total;
}
/**
 * Multiplies every element of nums with SafeMul, starting from 1.
 *
 * @param nums        Factors to multiply.
 * @param location    Caller-supplied tag passed through to SafeMul.
 * @param returnValue Passed through to SafeMul and OverFlowReturnValue.
 * @return 0 as soon as SafeEqual treats a factor as zero. If CheckOverFlow
 *         flags the running product, the result of OverFlowReturnValue for that
 *         product. Otherwise the full product.
 */
template<typename T>
T SafeMulAll(const std::vector<T>& nums, const std::string &location, bool returnValue = true)
{
    T product = 1;
    for (auto it = nums.begin(); it != nums.end(); ++it) {
        const auto &factor = *it;
        if (SafeEqual(factor, T(0))) {
            return 0;
        }
        product = SafeMul(product, factor, location, returnValue);
        if (!CheckOverFlow(product)) {
            continue;
        }
        return OverFlowReturnValue(product, returnValue);
    }
    return product;
}
/**
 * Computes the sum of nums1[i] * nums2[i] using SafeMul and SafeAdd.
 * Each element of nums2 is converted to T1 before it is multiplied.
 *
 * @param nums1       First factor list; its element type is the result type.
 * @param nums2       Second factor list; must have the same length as nums1.
 * @param location    Caller-supplied tag used in log output and passed to the helpers.
 * @param returnValue Passed through to SafeMul and SafeAdd.
 * @return The largest T1 when the two lists differ in length (regardless of
 *         returnValue). Otherwise the running sum, which is returned early once
 *         CheckOverFlow flags it.
 */
template<typename T1, typename T2>
T1 SafeMulAddAll(const std::vector<T1>& nums1, const std::vector<T2>& nums2, const std::string &location,
                 bool returnValue = true)
{
    const size_t count = nums1.size();
    if (count != nums2.size()) {
        LogDebug("The multiplication and addition all function receive wrong param size in %s", location.c_str());
        return std::numeric_limits<T1>::max();
    }

    T1 accumulated = 0;
    size_t index = 0;
    while (index < count) {
        const T1 term = SafeMul(nums1[index], static_cast<T1>(nums2[index]), location, returnValue);
        accumulated = SafeAdd(accumulated, term, location, returnValue);
        if (CheckOverFlow(accumulated)) {
            break;
        }
        ++index;
    }
    return accumulated;
}
/**
 * Extracts a bit field from a 64-bit value.
 *
 * @param regValue  Value to read bits from.
 * @param startBit  Index of the field's lowest bit (0 is the least significant bit).
 * @param bitLength Number of bits in the field.
 * @return The field shifted down to bit 0. Returns 0 when startBit is 64 or
 *         more, or when bitLength is greater than 64. A field that extends past
 *         bit 63 yields only the bits that exist.
 */
inline uint64_t ExtractKBits(uint64_t regValue, uint32_t startBit, uint32_t bitLength)
{
    constexpr uint32_t bitSize = 64;
    if (startBit >= bitSize || bitLength > bitSize) {
        return 0;
    }
    // Shifting by the full width is undefined, so a 64-bit field gets an explicit all-ones mask.
    const uint64_t mask = (bitLength == bitSize) ? std::numeric_limits<uint64_t>::max()
                                                 : ((1ULL << bitLength) - 1);
    return (regValue >> startBit) & mask;
}
}
#endif