/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "sqrt_aicpu.h"
#include <complex>
#include "cpu_kernel_utils.h"
#include "utils/kernel_util.h"
using namespace std;
namespace {
// Operator name; used in the debug log and passed to REGISTER_CPU_KERNEL.
constexpr const char *kSqrt = "Sqrt";

// Tensor counts handed to NormalCheck when validating the kernel context.
constexpr uint32_t kInputNum = 1;
constexpr uint32_t kOutputNum = 1;

// Inputs with at most this many elements are processed in a single pass;
// larger inputs are sharded through ParallelFor.
constexpr int64_t kParallelDataNums = 8 * 1024;
}  // namespace
namespace aicpu {
template <typename T>
uint32_t SqrtCpuKernel::DoComputeReal(const CpuKernelContext &ctx) {
auto *input = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto *output = reinterpret_cast<T *>(ctx.Output(0)->GetData());
int64_t data_num = ctx.Input(0)->NumElements();
if (data_num <= kParallelDataNums) {
Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Unaligned> tensor_x(input, data_num);
Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Unaligned> tensor_y(output, data_num);
tensor_y = tensor_x.sqrt();
} else {
uint32_t min_core_num = 1;
int64_t max_core_num = std::max(min_core_num, aicpu::CpuKernelUtils::GetCPUNum(ctx) - kResvCpuNum);
max_core_num = max_core_num > data_num ? data_num : max_core_num;
auto shard_sqrt = [&input, &output](int64_t begin, int64_t end) {
int64_t length = end - begin;
Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Unaligned> tensor_x(input + begin, length);
Eigen::TensorMap<Eigen::Tensor<T, 1>, Eigen::Unaligned> tensor_y(output + begin, length);
tensor_y = tensor_x.sqrt();
};
KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_sqrt),
"Sqrt Compute failed.");
}
return KERNEL_STATUS_OK;
}
/**
 * @brief Element-wise square root for complex element types.
 *
 * Reads the data buffer of input 0 and writes the data buffer of output 0,
 * both reinterpreted as arrays of T, applying std::sqrt to each element.
 * Inputs of at most kParallelDataNums elements are processed in one loop;
 * larger inputs are sharded through CpuKernelUtils::ParallelFor.
 *
 * @tparam T element type of both buffers (Compute instantiates this with
 *           std::complex<float> and std::complex<double>)
 * @param ctx kernel context that supplies the input and output tensors
 * @return KERNEL_STATUS_OK on success; a ParallelFor failure is handled by
 *         KERNEL_HANDLE_ERROR
 */
template <typename T>
uint32_t SqrtCpuKernel::DoComputeComplex(const CpuKernelContext &ctx) {
  auto *input = reinterpret_cast<T *>(ctx.Input(0)->GetData());
  auto *output = reinterpret_cast<T *>(ctx.Output(0)->GetData());
  const int64_t data_num = ctx.Input(0)->NumElements();

  // Processes the half-open element range [begin, end).
  auto shard_sqrt = [input, output](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      output[i] = std::sqrt(input[i]);
    }
  };

  if (data_num <= kParallelDataNums) {
    shard_sqrt(0, data_num);
    return KERNEL_STATUS_OK;
  }

  // Subtract in signed 64-bit arithmetic. Done in unsigned 32-bit, a CPU count
  // below kResvCpuNum wraps around to ~4e9, which (after clamping to data_num)
  // degenerates into one shard per element.
  const int64_t spare_cores =
      static_cast<int64_t>(aicpu::CpuKernelUtils::GetCPUNum(ctx)) - static_cast<int64_t>(kResvCpuNum);
  int64_t max_core_num = std::max<int64_t>(1, spare_cores);
  if (max_core_num > data_num) {
    max_core_num = data_num;
  }

  KERNEL_HANDLE_ERROR(CpuKernelUtils::ParallelFor(ctx, data_num, data_num / max_core_num, shard_sqrt),
                      "Sqrt Compute failed.");
  return KERNEL_STATUS_OK;
}
/**
 * @brief Kernel entry point: validates the context and dispatches by dtype.
 *
 * Runs NormalCheck for one input and one output, requires the input and
 * output tensors to share the same data type and the same data size, then
 * forwards to DoComputeReal (float16 / float / double) or DoComputeComplex
 * (complex64 / complex128).
 *
 * @param ctx kernel context that supplies the input and output tensors
 * @return status of the selected compute routine, or
 *         KERNEL_STATUS_PARAM_INVALID when a check fails or the data type is
 *         not supported
 */
uint32_t SqrtCpuKernel::Compute(CpuKernelContext &ctx) {
  KERNEL_HANDLE_ERROR(NormalCheck(ctx, kInputNum, kOutputNum), "Check Sqrt params failed.");

  auto *x = ctx.Input(0);
  auto *y = ctx.Output(0);

  const DataType x_dtype = x->GetDataType();
  const DataType y_dtype = y->GetDataType();
  KERNEL_CHECK_FALSE((x_dtype == y_dtype), KERNEL_STATUS_PARAM_INVALID,
                     "The data type of input [%s] must be the same as output [%s].",
                     DTypeStr(x_dtype).c_str(), DTypeStr(y_dtype).c_str());
  KERNEL_CHECK_FALSE((x->GetDataSize() == y->GetDataSize()), KERNEL_STATUS_PARAM_INVALID,
                     "The data size of input [%lu] must be the same as output [%lu].",
                     x->GetDataSize(), y->GetDataSize());
  KERNEL_LOG_DEBUG("%s op input[x] data type is [%s].", kSqrt, DTypeStr(x_dtype).c_str());

  switch (x_dtype) {
    // Real floating-point types.
    case DT_FLOAT16: {
      return DoComputeReal<Eigen::half>(ctx);
    }
    case DT_FLOAT: {
      return DoComputeReal<float>(ctx);
    }
    case DT_DOUBLE: {
      return DoComputeReal<double>(ctx);
    }
    // Complex types.
    case DT_COMPLEX64: {
      return DoComputeComplex<std::complex<float>>(ctx);
    }
    case DT_COMPLEX128: {
      return DoComputeComplex<std::complex<double>>(ctx);
    }
    default: {
      KERNEL_LOG_ERROR("Sqrt invalid input type [%s].", DTypeStr(x_dtype).c_str());
      return KERNEL_STATUS_PARAM_INVALID;
    }
  }
}
// Passes the op name kSqrt ("Sqrt") and SqrtCpuKernel to the project's registration macro.
// NOTE(review): presumably this makes the kernel discoverable by op name at runtime - confirm
// against the macro definition.
REGISTER_CPU_KERNEL(kSqrt, SqrtCpuKernel);
}