* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "rsqrt_aicpu.h"
#include <cfloat>
#include <complex>
#include <unsupported/Eigen/CXX11/Tensor>
#include "cpu_kernel_utils.h"
#include "cpu_types.h"
#include "log.h"
#include "status.h"
#include "utils/kernel_util.h"
namespace {
// Operator name handed to REGISTER_CPU_KERNEL at the bottom of this file.
constexpr const char *kRsqrt = "Rsqrt";

// Expected number of input / output tensors, forwarded to NormalCheck() in Compute().
constexpr std::uint32_t kRsqrtInputNum = 1;
constexpr std::uint32_t kRsqrtOutputNum = 1;

// Element counts above which the real-valued / complex-valued compute paths
// stop looping serially and go through CpuKernelUtils::ParallelFor instead.
constexpr int64_t kParallelDataNums = 8 * 1024;
constexpr int64_t kParallelComplexDataNums = 4 * 1024;
}  // namespace
namespace aicpu {
namespace {
// Chooses how many shards a parallel run should be split into.
//
// Takes the CPU count reported by CpuKernelUtils::GetCPUNum(ctx), subtracts
// kResvCpuNum, raises the result to at least 1, and finally caps it at dataNum.
// NOTE(review): the result is only >= 1 when dataNum >= 1; the callers in this
// file only invoke it on the parallel path where dataNum exceeds a positive threshold.
int64_t GetMaxCoreNum(const CpuKernelContext &ctx, int64_t dataNum) {
    constexpr int64_t kMinCoreNum = 1;
    const int64_t available = static_cast<int64_t>(CpuKernelUtils::GetCPUNum(ctx)) - kResvCpuNum;
    const int64_t usable = (available < kMinCoreNum) ? kMinCoreNum : available;
    return (dataNum < usable) ? dataNum : usable;
}
}  // namespace
template <typename T>
std::uint32_t RsqrtCpuKernel::RsqrtCompute(const Tensor *x, const Tensor *y, int64_t dataNum,
const CpuKernelContext &ctx) const {
auto inputx = reinterpret_cast<T *>(x->GetData());
KERNEL_CHECK_NULLPTR(inputx, static_cast<std::uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Get input data failed")
auto outputy = reinterpret_cast<T *>(y->GetData());
KERNEL_CHECK_NULLPTR(outputy, static_cast<std::uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Get output data failed")
if (dataNum <= kParallelDataNums) {
for (int64_t i = 0; i < dataNum; i++) {
outputy[i] = static_cast<T>(1) / sqrt(inputx[i]);
}
} else {
int64_t max_core_num = GetMaxCoreNum(ctx, dataNum);
auto shard_rsqrt = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
outputy[i] = static_cast<T>(1) / sqrt(inputx[i]);
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, dataNum, dataNum / max_core_num, shard_rsqrt),
"Rsqrt Compute failed.");
}
return static_cast<std::uint32_t>(KERNEL_STATUS_OK);
}
template <typename T>
std::uint32_t RsqrtCpuKernel::RsqrtComputeComplex(const Tensor *x, const Tensor *y, int64_t dataNum,
const CpuKernelContext &ctx) const {
auto inputx = reinterpret_cast<T *>(x->GetData());
KERNEL_CHECK_NULLPTR(inputx, static_cast<std::uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Get input data failed")
auto outputy = reinterpret_cast<T *>(y->GetData());
KERNEL_CHECK_NULLPTR(outputy, static_cast<std::uint32_t>(KERNEL_STATUS_PARAM_INVALID), "Get output data failed")
if (dataNum <= kParallelComplexDataNums) {
for (int64_t i = 0; i < dataNum; i++) {
outputy[i] =
sqrt(conj(inputx[i])) / sqrt(inputx[i].real() * inputx[i].real() + inputx[i].imag() * inputx[i].imag());
}
} else {
int64_t max_core_num = GetMaxCoreNum(ctx, dataNum);
auto shard_rsqrt = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
outputy[i] =
sqrt(conj(inputx[i])) / sqrt(inputx[i].real() * inputx[i].real() + inputx[i].imag() * inputx[i].imag());
}
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, dataNum, dataNum / max_core_num, shard_rsqrt),
"Rsqrt Compute failed.");
}
return static_cast<std::uint32_t>(KERNEL_STATUS_OK);
}
// Kernel entry point: validates the single input / single output pair and
// dispatches to the typed compute routine matching the input data type.
//
// Returns KERNEL_STATUS_PARAM_INVALID when input and output disagree on data
// type or data size, or when the data type is not one of
// float16 / float / double / complex64 / complex128.
// Any non-OK status from the typed routine is reported as KERNEL_STATUS_INNER_ERROR.
// A NormalCheck failure is routed through KERNEL_HANDLE_ERROR (defined outside this file).
std::uint32_t RsqrtCpuKernel::Compute(CpuKernelContext &ctx) {
    KERNEL_HANDLE_ERROR(NormalCheck(ctx, kRsqrtInputNum, kRsqrtOutputNum), "Check Rsqrt params failed.");

    auto *input = ctx.Input(0);
    auto *output = ctx.Output(0);

    // Input and output must share the same element type ...
    if (input->GetDataType() != output->GetDataType()) {
        KERNEL_LOG_ERROR("The data type of the input [%s] need be the same as the output [%s]",
                         DTypeStr(input->GetDataType()).c_str(),
                         DTypeStr(output->GetDataType()).c_str());
        return static_cast<std::uint32_t>(KERNEL_STATUS_PARAM_INVALID);
    }
    // ... and the same data size.
    if (input->GetDataSize() != output->GetDataSize()) {
        KERNEL_LOG_ERROR("The data size of the input [%lu] need be the same as the output [%lu]",
                         input->GetDataSize(), output->GetDataSize());
        return static_cast<std::uint32_t>(KERNEL_STATUS_PARAM_INVALID);
    }

    const Tensor *x = input;
    const Tensor *y = output;
    const int64_t dataNum = x->NumElements();
    const DataType datatype = x->GetDataType();

    std::uint32_t res = static_cast<std::uint32_t>(KERNEL_STATUS_OK);
    switch (datatype) {
        case DT_FLOAT16:
            res = RsqrtCompute<Eigen::half>(x, y, dataNum, ctx);
            break;
        case DT_FLOAT:
            res = RsqrtCompute<float>(x, y, dataNum, ctx);
            break;
        case DT_DOUBLE:
            res = RsqrtCompute<double>(x, y, dataNum, ctx);
            break;
        case DT_COMPLEX64:
            res = RsqrtComputeComplex<std::complex<float>>(x, y, dataNum, ctx);
            break;
        case DT_COMPLEX128:
            res = RsqrtComputeComplex<std::complex<double>>(x, y, dataNum, ctx);
            break;
        default:
            KERNEL_LOG_ERROR("Rsqrt invalid input type [%s]", DTypeStr(datatype).c_str());
            return static_cast<std::uint32_t>(KERNEL_STATUS_PARAM_INVALID);
    }

    // Collapse every failure of the typed routine into a single inner-error status.
    return (res == static_cast<std::uint32_t>(KERNEL_STATUS_OK))
               ? static_cast<std::uint32_t>(KERNEL_STATUS_OK)
               : static_cast<std::uint32_t>(KERNEL_STATUS_INNER_ERROR);
}
// Passes the kRsqrt name and the RsqrtCpuKernel class to REGISTER_CPU_KERNEL.
// NOTE(review): the macro is defined outside this file; presumably it registers
// this kernel with the AICPU framework under the "Rsqrt" op name - confirm.
REGISTER_CPU_KERNEL(kRsqrt, RsqrtCpuKernel);
}