* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <gtest/gtest.h>
#include <Eigen/Core>
#include <cmath>
#include <complex>
#include "utils/aicpu_test_utils.h"
#include "cpu_kernel_utils.h"
#include "node_def_builder.h"
using namespace std;
using namespace aicpu;
class TEST_RSQRT_UT : public testing::Test {};
#define CREATE_NODEDEF(shapes, data_types, datas) \
auto node_def = CpuKernelUtils::CreateNodeDef(); \
NodeDefBuilder(node_def.get(), "Rsqrt", "Rsqrt") \
.Input({"x", data_types[0], shapes[0], datas[0]}) \
.Output({"y", data_types[1], shapes[1], datas[1]})
TEST_F(TEST_RSQRT_UT, DATA_TYPE_FLOAT_SUCC)
{
vector<DataType> data_types = {DT_FLOAT, DT_FLOAT};
vector<vector<int64_t>> shapes = {{5, 5}, {5, 5}};
float input[25];
for (int i = 0; i < 25; i++) {
input[i] = static_cast<float>(i + 1);
}
float output[25] = {0.0f};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
float output_exp[25];
for (int i = 0; i < 25; i++) {
output_exp[i] = 1.0f / std::sqrt(static_cast<float>(i + 1));
}
EXPECT_EQ(CompareResult(output, output_exp, 25), true);
}
TEST_F(TEST_RSQRT_UT, DATA_TYPE_DOUBLE_SUCC)
{
vector<DataType> data_types = {DT_DOUBLE, DT_DOUBLE};
vector<vector<int64_t>> shapes = {{5, 5}, {5, 5}};
double input[25];
for (int i = 0; i < 25; i++) {
input[i] = static_cast<double>(i + 1);
}
double output[25] = {0.0};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
double output_exp[25];
for (int i = 0; i < 25; i++) {
output_exp[i] = 1.0 / std::sqrt(static_cast<double>(i + 1));
}
EXPECT_EQ(CompareResult(output, output_exp, 25), true);
}
TEST_F(TEST_RSQRT_UT, DATA_TYPE_FLOAT16_SUCC)
{
vector<DataType> data_types = {DT_FLOAT16, DT_FLOAT16};
vector<vector<int64_t>> shapes = {{5, 5}, {5, 5}};
Eigen::half input[25];
for (int i = 0; i < 25; i++) {
input[i] = static_cast<Eigen::half>(i + 1);
}
Eigen::half output[25];
for (int i = 0; i < 25; i++) {
output[i] = static_cast<Eigen::half>(0);
}
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
Eigen::half output_exp[25];
for (int i = 0; i < 25; i++) {
output_exp[i] = static_cast<Eigen::half>(1.0f / std::sqrt(static_cast<float>(i + 1)));
}
EXPECT_EQ(CompareResult(output, output_exp, 25), true);
}
TEST_F(TEST_RSQRT_UT, DATA_TYPE_FLOAT_MULTICORE_SUCC)
{
constexpr int N = 9 * 1024;
vector<DataType> data_types = {DT_FLOAT, DT_FLOAT};
vector<vector<int64_t>> shapes = {{9, 1024}, {9, 1024}};
float *input = new float[N];
float *output = new float[N];
float *output_exp = new float[N];
for (int i = 0; i < N; i++) {
input[i] = static_cast<float>(i % 100 + 1);
output[i] = 0.0f;
output_exp[i] = 1.0f / std::sqrt(static_cast<float>(i % 100 + 1));
}
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
EXPECT_EQ(CompareResult(output, output_exp, N), true);
delete[] input;
delete[] output;
delete[] output_exp;
}
TEST_F(TEST_RSQRT_UT, DATA_TYPE_COMPLEX64_SUCC)
{
vector<DataType> data_types = {DT_COMPLEX64, DT_COMPLEX64};
vector<vector<int64_t>> shapes = {{5, 5}, {5, 5}};
std::complex<float> input[25];
for (int i = 0; i < 25; i++) {
input[i] = std::complex<float>(static_cast<float>(i + 1), 0.0f);
}
std::complex<float> output[25] = {};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
std::complex<float> output_exp[25];
for (int i = 0; i < 25; i++) {
float a = static_cast<float>(i + 1);
output_exp[i] = std::complex<float>(1.0f / std::sqrt(a), 0.0f);
}
EXPECT_EQ(CompareResult(output, output_exp, 25), true);
}
TEST_F(TEST_RSQRT_UT, DATA_TYPE_COMPLEX128_SUCC)
{
vector<DataType> data_types = {DT_COMPLEX128, DT_COMPLEX128};
vector<vector<int64_t>> shapes = {{5, 5}, {5, 5}};
std::complex<double> input[25];
for (int i = 0; i < 25; i++) {
input[i] = std::complex<double>(static_cast<double>(i + 1), 0.0);
}
std::complex<double> output[25] = {};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
std::complex<double> output_exp[25];
for (int i = 0; i < 25; i++) {
double a = static_cast<double>(i + 1);
output_exp[i] = std::complex<double>(1.0 / std::sqrt(a), 0.0);
}
EXPECT_EQ(CompareResult(output, output_exp, 25), true);
}
TEST_F(TEST_RSQRT_UT, DATA_TYPE_COMPLEX64_MULTICORE_SUCC)
{
constexpr int N = 5 * 1024;
vector<DataType> data_types = {DT_COMPLEX64, DT_COMPLEX64};
vector<vector<int64_t>> shapes = {{5, 1024}, {5, 1024}};
std::complex<float> *input = new std::complex<float>[N];
std::complex<float> *output = new std::complex<float>[N];
std::complex<float> *output_exp = new std::complex<float>[N];
for (int i = 0; i < N; i++) {
float a = static_cast<float>(i % 100 + 1);
input[i] = std::complex<float>(a, 0.0f);
output[i] = std::complex<float>(0.0f, 0.0f);
output_exp[i] = std::complex<float>(1.0f / std::sqrt(a), 0.0f);
}
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
EXPECT_EQ(CompareResult(output, output_exp, N), true);
delete[] input;
delete[] output;
delete[] output_exp;
}
TEST_F(TEST_RSQRT_UT, INPUT_ZERO_FLOAT_SUCCESS)
{
vector<DataType> data_types = {DT_FLOAT, DT_FLOAT};
vector<vector<int64_t>> shapes = {{2, 2}, {2, 2}};
float input[4] = {0.0f, 0.0f, 0.0f, 0.0f};
float output[4] = {0.0f};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
}
TEST_F(TEST_RSQRT_UT, INPUT_ZERO_DOUBLE_SUCCESS)
{
vector<DataType> data_types = {DT_DOUBLE, DT_DOUBLE};
vector<vector<int64_t>> shapes = {{2, 2}, {2, 2}};
double input[4] = {0.0, 0.0, 0.0, 0.0};
double output[4] = {0.0};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
}
TEST_F(TEST_RSQRT_UT, INPUT_ZERO_FLOAT16_SUCCESS)
{
vector<DataType> data_types = {DT_FLOAT16, DT_FLOAT16};
vector<vector<int64_t>> shapes = {{2, 2}, {2, 2}};
Eigen::half input[4] = {static_cast<Eigen::half>(0)};
Eigen::half output[4] = {static_cast<Eigen::half>(0)};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_OK);
}
TEST_F(TEST_RSQRT_UT, INPUT_SHAPE_EXCEPTION)
{
vector<DataType> data_types = {DT_DOUBLE, DT_DOUBLE};
vector<vector<int64_t>> shapes = {{2, 10}, {2, 11}};
double input[20] = {1.0};
double output[22] = {0.0};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_PARAM_INVALID);
}
TEST_F(TEST_RSQRT_UT, INPUT_DTYPE_EXCEPTION)
{
vector<DataType> data_types = {DT_DOUBLE, DT_FLOAT};
vector<vector<int64_t>> shapes = {{2, 11}, {2, 11}};
double input[22] = {1.0};
float output[22] = {0.0f};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_PARAM_INVALID);
}
TEST_F(TEST_RSQRT_UT, INPUT_NULL_EXCEPTION)
{
vector<DataType> data_types = {DT_DOUBLE, DT_DOUBLE};
vector<vector<int64_t>> shapes = {{2, 11}, {2, 11}};
double output[22] = {0.0};
vector<void*> datas = {(void*)nullptr, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_PARAM_INVALID);
}
TEST_F(TEST_RSQRT_UT, OUTPUT_NULL_EXCEPTION)
{
vector<DataType> data_types = {DT_DOUBLE, DT_DOUBLE};
vector<vector<int64_t>> shapes = {{2, 11}, {2, 11}};
double input[22] = {1.0};
vector<void*> datas = {(void*)input, (void*)nullptr};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_PARAM_INVALID);
}
TEST_F(TEST_RSQRT_UT, INPUT_BOOL_UNSUPPORT)
{
vector<DataType> data_types = {DT_BOOL, DT_BOOL};
vector<vector<int64_t>> shapes = {{2, 11}, {2, 11}};
bool input[22] = {true};
bool output[22] = {false};
vector<void*> datas = {(void*)input, (void*)output};
CREATE_NODEDEF(shapes, data_types, datas);
RUN_KERNEL(node_def, HOST, KERNEL_STATUS_PARAM_INVALID);
}