* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_topk.cpp
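* \brief On-board TopK operator tests: run TopK on the device and compare the returned values and indices with golden data.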
*/
#include "test_suite_stest_ops.h"
#include "test_dev_func_runner.h"
using namespace npu::tile_fwk;
// Test fixture for the on-board TopK cases; all setup/teardown is inherited from
// the shared STest ops suite (TEST_F derives from this class, so it must stay non-final).
class TopkOnBoardTest : public npu::tile_fwk::stest::TestSuite_STest_Ops_Aihac {
};
/**
 * Parameters for one on-board TopK case.
 *
 * Every member has a default initializer so a default-constructed instance never
 * carries indeterminate values. The struct remains an aggregate (C++14+), so
 * brace initialization in declaration order {shape0, shape1, k, isLargest} works.
 */
struct TopKParams {
    int32_t shape0{0};    // number of rows of the FP32 input
    int32_t shape1{0};    // length of the last axis, along which TopK selects
    int32_t k{0};         // number of elements kept per row
    bool isLargest{true}; // true: keep the k largest values; false: the k smallest
};
void TopKOnBoardFunc(TopKParams& params)
{
AclInit(nullptr);
RuntimeSetDevice(GetDeviceIdByEnvVar());
int32_t shape0 = params.shape0;
int32_t shape1 = params.shape1;
int32_t k = params.k;
bool isLargest = params.isLargest;
uint64_t inputSize = shape0 * shape1 * sizeof(float);
uint64_t outputSize = shape0 * k * sizeof(float);
uint8_t* out_ptr = allocDevAddr(outputSize);
uint8_t* out_ptr1 = allocDevAddr(outputSize);
PROGRAM("TOPK")
{
std::vector<int64_t> input_shape = {shape0, shape1};
std::vector<int64_t> output_shape = {shape0, k};
void* x_ptr = readToDev(GetGoldenDir() + "/x.bin", inputSize);
TileShape::Current().SetVecTile({shape0, shape1});
Tensor input_a(DataType::DT_FP32, input_shape, (uint8_t*)x_ptr, "A");
auto output = std::make_tuple(
Tensor(DataType::DT_FP32, output_shape, out_ptr, "npu_val"),
Tensor(DataType::DT_FP32, output_shape, out_ptr1, "resDics"));
config::SetBuildStatic(true);
FUNCTION("TOPK_T", {input_a, std::get<0>(output), std::get<1>(output)})
{
output = TopK(input_a, k, -1, isLargest);
}
}
DevFuncRunner::Run(Program::GetInstance().GetLastFunction());
std::vector<float> golden_val(shape0 * k);
std::vector<int32_t> golden_idx(shape0 * k);
std::vector<float> dev_val(shape0 * k);
std::vector<int32_t> dev_idx(shape0 * k);
CopyFromTensor((uint8_t*)dev_val.data(), (uint8_t*)out_ptr, outputSize);
CopyFromTensor((uint8_t*)dev_idx.data(), (uint8_t*)out_ptr1, outputSize);
readInput(GetGoldenDir() + "/val.bin", golden_val);
readInput(GetGoldenDir() + "/idx.bin", golden_idx);
int ret_val = resultCmp(golden_val, dev_val, 0.001f);
int ret_idx = resultCmp(golden_idx, dev_idx, 0);
EXPECT_EQ(ret_val, true);
EXPECT_EQ(ret_idx, true);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_128_32_32_topk)
{
    // 128 x 32 input, k equals the full last-axis length, largest elements.
    TopKParams params{128, 32, 32, true};
    TopKOnBoardFunc(params);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_128_32_16_topk)
{
    // 128 x 32 input, keep half of the last axis (k = 16), largest elements.
    TopKParams params{128, 32, 16, true};
    TopKOnBoardFunc(params);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_4_32_8_topk)
{
    // 4 x 32 input, k = 8, largest elements.
    TopKParams params{4, 32, 8, true};
    TopKOnBoardFunc(params);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_2_16_8_topk)
{
    // 2 x 16 input, k = 8, largest elements.
    TopKParams params{2, 16, 8, true};
    TopKOnBoardFunc(params);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_2_8_4_topk)
{
    // 2 x 8 input, k = 4, largest elements.
    TopKParams params{2, 8, 4, true};
    TopKOnBoardFunc(params);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_1_8_4_topk)
{
    // Single-row case: 1 x 8 input, k = 4, largest elements.
    TopKParams params{1, 8, 4, true};
    TopKOnBoardFunc(params);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_2_288_15_topk)
{
    // Non-power-of-two sizes: 2 x 288 input, k = 15, largest elements.
    TopKParams params{2, 288, 15, true};
    TopKOnBoardFunc(params);
}
TEST_F(TopkOnBoardTest, test_operation_tensor_2_288_15_topk_reverse)
{
    // Same shape as the non-reverse case, but selects the smallest elements.
    TopKParams params{2, 288, 15, false};
    TopKOnBoardFunc(params);
}