* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "round_aicpu.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "cpu_kernel_utils.h"
#include "utils/kernel_util.h"
#include "cpu_types.h"
#include "log.h"
#include "status.h"
namespace {
// Operator name passed to REGISTER_CPU_KERNEL together with RoundCpuKernel.
const char *const kRound = "Round";
/**
 * @brief Compare two doubles using a relative tolerance.
 * @param a first value
 * @param b second value
 * @return true when both values are finite and differ by at most 1e-14 times the larger
 *         magnitude; when either value is infinite or NaN the result is the plain a == b
 *         (so NaN never compares equal, and infinities only match the same infinity).
 */
bool IsValueEqualForDouble(double a, double b) {
  // Non-finite operands would poison the relative-tolerance arithmetic below.
  if (!std::isfinite(a) || !std::isfinite(b)) {
    return a == b;
  }
  constexpr double kRelativeTolerance = 1e-14;
  const double largest_magnitude = std::max(std::fabs(a), std::fabs(b));
  const double difference = std::fabs(a - b);
  return difference <= kRelativeTolerance * largest_magnitude;
}
/**
 * @brief Round x to an integral value; a fractional part of one half goes to the even neighbour.
 *
 * The value is floored first. A fractional part above one half bumps the result up by one.
 * A fractional part that compares equal to one half (IsValueEqualForDouble for double,
 * aicpu::IsValueEqual<T> for every other type) bumps it up only when the floored value is odd.
 *
 * @param x value to round
 * @return the rounded value, in the same type as x
 */
template <typename T>
inline auto DoCompute(const T &x) -> const T {
  constexpr bool kIsDouble = std::is_same<T, double>::value;
  const T half = T(.5);
  const T floor_val = Eigen::numext::floor(x);
  const T frac = x - floor_val;

  // Midpoint detection goes through the equality helpers rather than a raw ==.
  bool at_midpoint = false;
  if (kIsDouble) {
    at_midpoint = IsValueEqualForDouble(frac, half);
  } else {
    at_midpoint = aicpu::IsValueEqual<T>(frac, half);
  }

  if (frac > half) {
    return floor_val + T(1.0);
  }
  if (!at_midpoint) {
    return floor_val;
  }

  // floor(x) - 2 * floor(x / 2) is the parity of floor(x): 1 when odd, 0 when even.
  const T parity = floor_val - T(2) * Eigen::numext::floor(half * x);
  bool floor_is_odd = false;
  if (kIsDouble) {
    floor_is_odd = IsValueEqualForDouble(parity, T(1));
  } else {
    floor_is_odd = aicpu::IsValueEqual<T>(parity, T(1));
  }
  return floor_is_odd ? floor_val + T(1) : floor_val;
}
/**
 * @brief Round x at a decimal position given by a pre-computed power of ten.
 * @param x            value to round
 * @param pow_decimals scale factor applied around the rounding step
 * @param neg_flag     true: divide by pow_decimals, round, multiply back;
 *                     false: multiply by pow_decimals, round, divide back
 * @return the rounded value, in the same type as x
 */
template <typename T>
inline auto ScalarRoundWithDecimals(const T &x, const T &pow_decimals, const bool &neg_flag) -> const T {
  if (neg_flag) {
    const T rounded = DoCompute<T>(x / pow_decimals);
    return rounded * pow_decimals;
  }
  const T rounded = DoCompute<T>(x * pow_decimals);
  return rounded / pow_decimals;
}
}
namespace aicpu {
template <typename T>
uint32_t RangeRound(CpuKernelContext &ctx, T *input, T *out, int64_t decimals) {
uint32_t ret = 0;
uint32_t max_core_num = std::max(1U, aicpu::CpuKernelUtils::GetCPUNum(ctx));
int64_t input_datasize = ctx.Input(0)->NumElements();
auto per_unit_size = CeilMultiple(input_datasize, max_core_num);
if (decimals == 0) {
auto shard_copy = [&input, &out](int64_t start, int64_t end) {
constexpr bool isInt = (Eigen::NumTraits<T>::IsInteger != 0);
if (isInt) {
for (int64_t i = start; i < end; ++i) {
out[i] = input[i];
}
} else {
for (int64_t i = start; i < end; ++i) {
out[i] = DoCompute<T>(input[i]);
}
}
};
ret = CpuKernelUtils::ParallelFor(ctx, input_datasize, per_unit_size, shard_copy);
} else {
bool neg_flag = false;
if (decimals < 0) {
decimals = -decimals;
neg_flag = true;
}
const T pow_decimals = static_cast<T>(std::pow(10, decimals));
auto shard_copy = [&input, &out, &pow_decimals, &neg_flag](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
out[i] = ScalarRoundWithDecimals<T>(input[i], pow_decimals, neg_flag);
}
};
ret = CpuKernelUtils::ParallelFor(ctx, input_datasize, per_unit_size, shard_copy);
}
return ret;
}
/**
 * @brief Tell whether the kernel can handle the given input data type.
 * @param input_type data type of input[0]
 * @return true for DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32 and DT_INT64;
 *         false (after logging an error) for anything else
 */
bool RoundCpuKernel::CheckSupported(DataType input_type) const {
  const bool supported = (input_type == DT_FLOAT16) || (input_type == DT_FLOAT) ||
                         (input_type == DT_DOUBLE) || (input_type == DT_INT32) ||
                         (input_type == DT_INT64);
  if (!supported) {
    KERNEL_LOG_ERROR("Unsupported input data type[%d]", static_cast<int32_t>(input_type));
  }
  return supported;
}
/**
 * @brief Kernel entry point: validate the tensors, read the optional "decimals" attribute
 *        and dispatch to RangeRound for the element type of input[0].
 * @param ctx kernel context providing input[0], output[0] and the attributes
 * @return KERNEL_STATUS_PARAM_INVALID when a tensor or its data pointer is missing or the
 *         data type is unsupported; otherwise the status returned by RangeRound
 */
uint32_t RoundCpuKernel::Compute(CpuKernelContext &ctx) {
  Tensor *input_tensor = ctx.Input(0);
  KERNEL_CHECK_NULLPTR(input_tensor, KERNEL_STATUS_PARAM_INVALID, "Get input[0] failed")
  Tensor *output_tensor = ctx.Output(0);
  KERNEL_CHECK_NULLPTR(output_tensor, KERNEL_STATUS_PARAM_INVALID, "Get output[0] failed")
  auto inputdata = input_tensor->GetData();
  KERNEL_CHECK_NULLPTR(inputdata, KERNEL_STATUS_PARAM_INVALID, "Get input[0] data failed")
  auto outputdata = output_tensor->GetData();
  KERNEL_CHECK_NULLPTR(outputdata, KERNEL_STATUS_PARAM_INVALID, "Get output[0] data failed")

  // "decimals" is optional; a missing attribute means rounding to whole numbers.
  auto attr = ctx.GetAttr("decimals");
  const int64_t decimals = (attr != nullptr) ? attr->GetInt() : 0;

  const DataType dtype = input_tensor->GetDataType();
  if (!CheckSupported(dtype)) {
    return KERNEL_STATUS_PARAM_INVALID;
  }

  // Each case reinterprets the raw buffers as the concrete element type.
  switch (dtype) {
    case DT_FLOAT16:
      return RangeRound(ctx, static_cast<Eigen::half *>(inputdata),
                        static_cast<Eigen::half *>(outputdata), decimals);
    case DT_FLOAT:
      return RangeRound(ctx, static_cast<float *>(inputdata),
                        static_cast<float *>(outputdata), decimals);
    case DT_DOUBLE:
      return RangeRound(ctx, static_cast<double *>(inputdata),
                        static_cast<double *>(outputdata), decimals);
    case DT_INT32:
      return RangeRound(ctx, static_cast<int32_t *>(inputdata),
                        static_cast<int32_t *>(outputdata), decimals);
    case DT_INT64:
      return RangeRound(ctx, static_cast<int64_t *>(inputdata),
                        static_cast<int64_t *>(outputdata), decimals);
    default:
      KERNEL_LOG_ERROR("Unsupported input data type[%d]", static_cast<int32_t>(dtype));
      return KERNEL_STATUS_PARAM_INVALID;
  }
}
// Associate RoundCpuKernel with the operator name held in kRound ("Round").
REGISTER_CPU_KERNEL(kRound, RoundCpuKernel);
}