/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
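/*!
 * \file topk.h
 * \brief Public TopK API: get the top k maximum or minimum values, and their corresponding indices,
 *        along the last dimension of a LocalTensor.
 */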
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_TOPK_H__
#endif
#ifndef LIB_SORT_TOPK_H
#define LIB_SORT_TOPK_H
#include "include/adv_api/sort/topk_utils.h"
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002 || __NPU_ARCH__ == 3510 || __NPU_ARCH__ == 5102 || \
__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113)
#include "kernel_tensor.h"
#include "kernel_utils.h"
#include "kernel_tiling/kernel_tiling.h"
#include "../../../impl/adv_api/detail/sort/topk/topk_common_utils.h"
#ifdef ASCENDC_CPU_DEBUG
#include "../../../impl/adv_api/detail/api_check/kernel_check/sort/topk/topk_check.h"
#endif
#include "../../../impl/adv_api/detail/api_check/kernel_api_check.h"
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201 || __NPU_ARCH__ == 2002)
#include "../../../impl/adv_api/detail/sort/topk/topk_common_impl.h"
#endif
#if defined(__NPU_ARCH__) && __NPU_ARCH__ == 2201
#include "../../../impl/adv_api/detail/sort/topk/topk_v220_impl.h"
#elif defined(__NPU_ARCH__) && __NPU_ARCH__ == 2002
#include "../../../impl/adv_api/detail/sort/topk/topk_v200_impl.h"
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510 || __NPU_ARCH__ == 5102 || __NPU_ARCH__ == 3003 || \
__NPU_ARCH__ == 3113)
#include "../../../impl/adv_api/detail/sort/topk/topk_c310_impl.h"
#endif
namespace AscendC {
#pragma begin_pipe(V)
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3510 || __NPU_ARCH__ == 5102 || __NPU_ARCH__ == 3003 || \
__NPU_ARCH__ == 3113)
* @ingroup TopK
* @brief Get the top k maximum or minimum values and their corresponding indices of the last dimension.
* @tparam T: Data type to be sorted, half or float.
* @tparam isInitIndex: Whether to transfer the index of the input data.
If the value is true, srcIndexLocal is the index of the input data.
If the value is false, the index is generated by the Topk API.
* @tparam isHasfinish: The isHasfinish parameter is used to specify that the sorting of some rows is invalid.
If the value is true, enable the function. If the value is false, disable the function.
In normal mode, isHasfinish can be set to true or false.
In small mode, isHasfinish can only be set to false.
* @tparam isReuseSrc: Whether temporary variables can reuse the input memory.
This parameter is reserved. Use the default value false.
* @tparam topkMode: Normal mode or small mode,
Small mode is recommended when the inner axis length is 32. Performance will be high.
* @param [out] dstValueLocal: Used to store k sorted values.
* @param [out] dstIndexLocal: Used to store indexes corresponding to sorted k values.
* @param [in] srcLocal: Input data to hold values to be sorted.
* @param [in] srcIndexLocal: The input data is used to store the index corresponding to the value of srcLocal.
* @param [in] finishLocal: Used to specify that the sort of some rows is an invalid sort with shape of (outter, 1).
* @param [in] tmpLocal: Temporary space for storing intermediate variables during internal calculation.
* @param [in] k: Obtain the first k maximum or minimum values and their corresponding indexes.
* @param [in] tilling: Tiling information required for TopK calculation.
* @param [in] topKInfo: Shape information of srcLocal.
* @param [in] isLargest: Descending or ascending order. The value true indicates descending order,
and the value false indicates ascending order.
*/
template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
enum TopKMode topkMode = TopKMode::TOPK_NORMAL, const TopKConfig& config = defaultTopKConfig>
__aicore__ inline void TopK(const LocalTensor<T>& dstValueLocal, const LocalTensor<int32_t>& dstIndexLocal,
const LocalTensor<T>& srcLocal, const LocalTensor<int32_t>& srcIndexLocal, const LocalTensor<bool>& finishLocal,
const LocalTensor<uint8_t>& tmpLocal, const int32_t k, const TopkTiling& tilling, const TopKInfo& topKInfo,
const bool isLargest = true)
{
if ASCEND_IS_AIC {
return;
}
#if ASCENDC_CPU_DEBUG
TopkInputCheck<T, isInitIndex, topkMode, config>(k, topKInfo);
#endif
LocalTensor<T> tempBuffer = tmpLocal.template ReinterpretCast<T>();
if constexpr (config.algo == TopKAlgo::RADIX_SELECT) {
static_assert((SupportType<T, uint8_t, int8_t, uint16_t, int16_t, half, float, bfloat16_t, uint32_t, int32_t,
uint64_t, int64_t>()),
"Type must be uint8_t/int8_t/uint16_t/int16_t/half/float/bfloat16_t/uint32_t/int32_t/uint64_t/int64_t in "
"topk radix select algorithm.");
static_assert((!isHasfinish), "Topk radix select algorithm cannot support to set finish flag.");
if constexpr (topkMode == TopKMode::TOPK_NORMAL) {
Reg::RadixSelectTopK::TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc, config>(dstValueLocal,
dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, tempBuffer, k, tilling, topKInfo, isLargest);
}
if constexpr (topkMode == TopKMode::TOPK_NSMALL) {
Reg::RadixSelectTopK::TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc, config>(dstValueLocal,
dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, tempBuffer, k, tilling, topKInfo, isLargest);
}
return;
}
if constexpr (config.algo == TopKAlgo::MERGE_SORT) {
static_assert((SupportType<T, half, float>()), "Type must be half/float in topk merge select algorithm.");
if constexpr (topkMode == TopKMode::TOPK_NORMAL) {
TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc>(dstValueLocal, dstIndexLocal, srcLocal,
srcIndexLocal, finishLocal, tempBuffer, k, tilling, topKInfo, isLargest);
}
if constexpr (topkMode == TopKMode::TOPK_NSMALL) {
TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc>(dstValueLocal, dstIndexLocal, srcLocal,
srcIndexLocal, finishLocal, tempBuffer, k, tilling, topKInfo, isLargest);
}
}
}
* @ingroup TopK
* @brief Get the top k maximum or minimum values and their corresponding indices of the last dimension.
* @tparam T: Data type to be sorted, half or float.
* @tparam isInitIndex: Whether to transfer the index of the input data.
If the value is true, srcIndexLocal is the index of the input data.
If the value is false, the index is generated by the Topk API.
* @tparam isHasfinish: The isHasfinish parameter is used to specify that the sorting of some rows is invalid.
If the value is true, enable the function. If the value is false, disable the function.
In normal mode, isHasfinish can be set to true or false.
In small mode, isHasfinish can only be set to false.
* @tparam isReuseSrc: Whether temporary variables can reuse the input memory.
This parameter is reserved. Use the default value false.
* @tparam topkMode: Normal mode or small mode,
Small mode is recommended when the inner axis length is 32. Performance will be high.
* @param [out] dstValueLocal: Used to store k sorted values.
* @param [out] dstIndexLocal: Used to store indexes corresponding to sorted k values.
* @param [in] srcLocal: Input data to hold values to be sorted.
* @param [in] srcIndexLocal: The input data is used to store the index corresponding to the value of srcLocal.
* @param [in] finishLocal: Used to specify that the sort of some rows is an invalid sort with shape of (outter, 1).
* @param [in] k: Obtain the first k maximum or minimum values and their corresponding indexes.
* @param [in] tilling: Tiling information required for TopK calculation.
* @param [in] topKInfo: Shape information of srcLocal.
* @param [in] isLargest: Descending or ascending order. The value true indicates descending order,
and the value false indicates ascending order.
*/
template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
enum TopKMode topkMode = TopKMode::TOPK_NORMAL, const TopKConfig& config = defaultTopKConfig>
__aicore__ inline void TopK(const LocalTensor<T>& dstValueLocal, const LocalTensor<int32_t>& dstIndexLocal,
const LocalTensor<T>& srcLocal, const LocalTensor<int32_t>& srcIndexLocal, const LocalTensor<bool>& finishLocal,
const int32_t k, const TopkTiling& tilling, const TopKInfo& topKInfo, const bool isLargest = true)
{
if ASCEND_IS_AIC {
return;
}
LocalTensor<T> stackTensor;
PopStackBuffer<T, TPosition::LCM>(stackTensor);
#if ASCENDC_CPU_DEBUG
auto stackTensorSize = stackTensor.GetSize();
bool ans = stackTensorSize >= tilling.tmpLocalSize;
ASCENDC_ASSERT(ans, {
KERNEL_LOG(KERNEL_ERROR,
"The pop stack buffer is insufficient, topk api need %d, but only %d exists.",
tilling.tmpLocalSize,
stackTensorSize);
});
TopkInputCheck<T, isInitIndex, topkMode, config>(k, topKInfo);
#endif
stackTensor.SetSize(tilling.tmpLocalSize);
if constexpr (config.algo == TopKAlgo::RADIX_SELECT) {
static_assert((SupportType<T, uint8_t, int8_t, uint16_t, int16_t, half, float, bfloat16_t, uint32_t, int32_t,
uint64_t, int64_t>()),
"Type must be uint8_t/int8_t/uint16_t/int16_t/half/float/bfloat16_t/uint32_t/int32_t/uint64_t/int64_t in "
"topk radix select algorithm.");
static_assert((!isHasfinish), "Topk radix select algorithm cannot support to set finish flag.");
if constexpr (topkMode == TopKMode::TOPK_NORMAL) {
Reg::RadixSelectTopK::TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc, config>(dstValueLocal,
dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, stackTensor, k, tilling, topKInfo, isLargest);
}
if constexpr (topkMode == TopKMode::TOPK_NSMALL) {
Reg::RadixSelectTopK::TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc, config>(dstValueLocal,
dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, stackTensor, k, tilling, topKInfo, isLargest);
}
return;
}
if constexpr (config.algo == TopKAlgo::MERGE_SORT) {
static_assert((SupportType<T, half, float>()), "Type must be half/float in topk merge select algorithm.");
if constexpr (topkMode == TopKMode::TOPK_NORMAL) {
TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc>(
dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, stackTensor, k, tilling,
topKInfo, isLargest);
}
if constexpr (topkMode == TopKMode::TOPK_NSMALL) {
TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc>(
dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, stackTensor, k, tilling,
topKInfo, isLargest);
}
}
}
#else
* @ingroup TopK
* @brief Get the top k maximum or minimum values and their corresponding indices of the last dimension.
* @tparam T: Data type to be sorted, half or float.
* @tparam isInitIndex: Whether to transfer the index of the input data.
If the value is true, srcIndexLocal is the index of the input data.
If the value is false, the index is generated by the Topk API.
* @tparam isHasfinish: The isHasfinish parameter is used to specify that the sorting of some rows is invalid.
If the value is true, enable the function. If the value is false, disable the function.
In normal mode, isHasfinish can be set to true or false.
In small mode, isHasfinish can only be set to false.
* @tparam isReuseSrc: Whether temporary variables can reuse the input memory.
This parameter is reserved. Use the default value false.
* @tparam topkMode: Normal mode or small mode,
Small mode is recommended when the inner axis length is 32. Performance will be high.
* @param [out] dstValueLocal: Used to store k sorted values.
* @param [out] dstIndexLocal: Used to store indexes corresponding to sorted k values.
* @param [in] srcLocal: Input data to hold values to be sorted.
* @param [in] srcIndexLocal: The input data is used to store the index corresponding to the value of srcLocal.
* @param [in] finishLocal: Used to specify that the sort of some rows is an invalid sort with shape of (outter, 1).
* @param [in] tmpLocal: Temporary space for storing intermediate variables during internal calculation.
* @param [in] k: Obtain the first k maximum or minimum values and their corresponding indexes.
* @param [in] tilling: Tiling information required for TopK calculation.
* @param [in] topKInfo: Shape information of srcLocal.
* @param [in] isLargest: Descending or ascending order. The value true indicates descending order,
and the value false indicates ascending order.
*/
template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
__aicore__ inline void TopK(const LocalTensor<T>& dstValueLocal, const LocalTensor<int32_t>& dstIndexLocal,
const LocalTensor<T>& srcLocal, const LocalTensor<int32_t>& srcIndexLocal, const LocalTensor<bool>& finishLocal,
const LocalTensor<uint8_t>& tmpLocal, const int32_t k, const TopkTiling& tilling, const TopKInfo& topKInfo,
const bool isLargest = true)
{
if ASCEND_IS_AIC {
return;
}
CHECK_FUNC_HIGHLEVEL_API(TopK, (T, isInitIndex, isHasfinish, isReuseSrc, topkMode), (
dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, tmpLocal, k, tilling, topKInfo, isLargest));
if constexpr (topkMode == TopKMode::TOPK_NORMAL) {
TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc>(dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal,
finishLocal, tmpLocal, k, tilling, topKInfo, isLargest);
}
if constexpr (topkMode == TopKMode::TOPK_NSMALL) {
TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc>(dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal,
finishLocal, tmpLocal, k, tilling, topKInfo, isLargest);
}
}
* @ingroup TopK
* @brief Get the top k maximum or minimum values and their corresponding indices of the last dimension.
* @tparam T: Data type to be sorted, half or float.
* @tparam isInitIndex: Whether to transfer the index of the input data.
If the value is true, srcIndexLocal is the index of the input data.
If the value is false, the index is generated by the Topk API.
* @tparam isHasfinish: The isHasfinish parameter is used to specify that the sorting of some rows is invalid.
If the value is true, enable the function. If the value is false, disable the function.
In normal mode, isHasfinish can be set to true or false.
In small mode, isHasfinish can only be set to false.
* @tparam isReuseSrc: Whether temporary variables can reuse the input memory.
This parameter is reserved. Use the default value false.
* @tparam topkMode: Normal mode or small mode,
Small mode is recommended when the inner axis length is 32. Performance will be high.
* @param [out] dstValueLocal: Used to store k sorted values.
* @param [out] dstIndexLocal: Used to store indexes corresponding to sorted k values.
* @param [in] srcLocal: Input data to hold values to be sorted.
* @param [in] srcIndexLocal: The input data is used to store the index corresponding to the value of srcLocal.
* @param [in] finishLocal: Used to specify that the sort of some rows is an invalid sort with shape of (outter, 1).
* @param [in] k: Obtain the first k maximum or minimum values and their corresponding indexes.
* @param [in] tilling: Tiling information required for TopK calculation.
* @param [in] topKInfo: Shape information of srcLocal.
* @param [in] isLargest: Descending or ascending order. The value true indicates descending order,
and the value false indicates ascending order.
*/
template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
__aicore__ inline void TopK(const LocalTensor<T>& dstValueLocal, const LocalTensor<int32_t>& dstIndexLocal,
const LocalTensor<T>& srcLocal, const LocalTensor<int32_t>& srcIndexLocal, const LocalTensor<bool>& finishLocal,
const int32_t k, const TopkTiling& tilling, const TopKInfo& topKInfo, const bool isLargest = true)
{
if ASCEND_IS_AIC {
return;
}
CHECK_FUNC_HIGHLEVEL_API(TopK, (T, isInitIndex, isHasfinish, isReuseSrc, topkMode), (
dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, k, tilling, topKInfo, isLargest));
if constexpr (topkMode == TopKMode::TOPK_NORMAL) {
TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc>(
dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, k, tilling, topKInfo, isLargest);
}
if constexpr (topkMode == TopKMode::TOPK_NSMALL) {
TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc>(
dstValueLocal, dstIndexLocal, srcLocal, srcIndexLocal, finishLocal, k, tilling, topKInfo, isLargest);
}
}
#endif
#pragma end_pipe
}
#endif
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_TOPK_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_TOPK_H__
#endif