* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file top_k_v2_apt.cpp
* \brief top k v2 impl
*/
#ifndef TOP_K_V2_APT_H
#define TOP_K_V2_APT_H
#include "arch35/radix_sort_top_k.h"
#include "arch35/radix_topk_constant.h"
#include "arch35/top_k_merge_sort.h"
#include "arch35/radix_sort_top_k_single_block.h"
#include "arch35/radix_sort_top_k_single_core.h"
#include "arch35/radix_sort_top_k_inter_core_template_optimization.h"
#include "arch35/sort_and_top_k_more_core.h"
using namespace AscendC;
using namespace SortAndTopK;
#define TOPK_COMMON_TILING_KEY_INT64 1004
#define TOPK_COMMON_TILING_KEY_INT32 1003
#define TOPK_COMMON_TILING_KEY_INT16 1002
#define TOPK_COMMON_TILING_KEY_INT8 1001
#define TOPK_COMMON_TILING_KEY_UINT64 2004
#define TOPK_COMMON_TILING_KEY_UINT32 2003
#define TOPK_COMMON_TILING_KEY_UINT16 2002
#define TOPK_COMMON_TILING_KEY_UINT8 2001
#define TOPK_COMMON_TILING_KEY_FLOAT 3003
#define TOPK_COMMON_TILING_KEY_FLOAT16 3002
#define TOPK_COMMON_TILING_KEY_BF16 4002
#define TOPK_MERGE_SORT_TILING_KEY_FLOAT 13003
#define TOPK_MERGE_SORT_TILING_KEY_FLOAT16 13002
#define TOPK_MERGE_SORT_TILING_KEY_BF16 14002
const uint32_t SINGLE_CORE_MODE = 1;
const uint32_t MULT_CORE_OPTIM_MODE = 4;
const uint32_t SORT_AND_TOP_K_MODE = 5;
template <typename T, typename UNSINGED_TYPE, int32_t NUM_PASS, typename T_INDEX, typename T_INDEX_TO>
__aicore__ inline void RadixSortTopKOpObject(
GM_ADDR x, GM_ADDR k, GM_ADDR values, GM_ADDR indices, GM_ADDR globalWorkGm, GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0);
bool isLargest = (tilingData.isLargest > 0) ? true : false;
bool isSort = (tilingData.isSort > 0) ? true : false;
TPipe tPipe;
if (isLargest) {
if (isSort) {
RadixSortTopK<T, UNSINGED_TYPE, NUM_PASS, true, true, T_INDEX, T_INDEX_TO> radixSortTopK;
radixSortTopK.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopK.ProcessTopK();
} else {
RadixSortTopK<T, UNSINGED_TYPE, NUM_PASS, true, false, T_INDEX, T_INDEX_TO> radixSortTopK;
radixSortTopK.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopK.ProcessTopK();
}
} else {
if (isSort) {
RadixSortTopK<T, UNSINGED_TYPE, NUM_PASS, false, true, T_INDEX, T_INDEX_TO> radixSortTopK;
radixSortTopK.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopK.ProcessTopK();
} else {
RadixSortTopK<T, UNSINGED_TYPE, NUM_PASS, false, false, T_INDEX, T_INDEX_TO> radixSortTopK;
radixSortTopK.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopK.ProcessTopK();
}
}
}
template <typename T, typename UNSINGED_TYPE, int32_t NUM_PASS, typename T_INDEX, typename T_INDEX_TO>
__aicore__ inline void SortAndTopKOpObject(GM_ADDR x, GM_ADDR values, GM_ADDR indices, GM_ADDR globalWorkGm,
GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIV_1_0);
bool isLargest = (tilingData.isLargest > 0) ? true : false;
TPipe tPipe;
if (isLargest) {
SortAndTopK::SortAndTopKMoreCore<T, T_INDEX_TO, UNSINGED_TYPE, T_INDEX, 1> sortAndTopKMoreCore;
sortAndTopKMoreCore.InitParam(x, values, indices, globalWorkGm, &tilingData, &tPipe);
sortAndTopKMoreCore.ProcessTopK();
} else {
SortAndTopK::SortAndTopKMoreCore<T, T_INDEX_TO, UNSINGED_TYPE, T_INDEX, 0> sortAndTopKMoreCore;
sortAndTopKMoreCore.InitParam(x, values, indices, globalWorkGm, &tilingData, &tPipe);
sortAndTopKMoreCore.ProcessTopK();
}
}
template <typename T, typename UNSINGED_TYPE, int32_t NUM_PASS, typename T_INDEX, typename T_INDEX_TO>
__aicore__ inline void RadixSortTopKSingleCoreOpObject(
GM_ADDR x, GM_ADDR k, GM_ADDR values, GM_ADDR indices, GM_ADDR globalWorkGm, GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
bool isLargest = (tilingData.isLargest > 0) ? true : false;
bool isSort = (tilingData.isSort > 0) ? true : false;
TPipe tPipe;
if (isLargest) {
if (isSort) {
RadixSortTopKSingleCore<T, UNSINGED_TYPE, NUM_PASS, true, true, T_INDEX, T_INDEX_TO> radixSortTopKSingleCore;
radixSortTopKSingleCore.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleCore.Process();
} else {
RadixSortTopKSingleCore<T, UNSINGED_TYPE, NUM_PASS, true, false, T_INDEX, T_INDEX_TO> radixSortTopKSingleCore;
radixSortTopKSingleCore.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleCore.Process();
}
} else {
if (isSort) {
RadixSortTopKSingleCore<T, UNSINGED_TYPE, NUM_PASS, false, true, T_INDEX, T_INDEX_TO> radixSortTopKSingleCore;
radixSortTopKSingleCore.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleCore.Process();
} else {
RadixSortTopKSingleCore<T, UNSINGED_TYPE, NUM_PASS, false, false, T_INDEX, T_INDEX_TO> radixSortTopKSingleCore;
radixSortTopKSingleCore.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleCore.Process();
}
}
}
template <typename T, typename T_INDEX_TO>
__aicore__ inline void RadixSortTopKMultiCoreOpObject(
GM_ADDR x, GM_ADDR k, GM_ADDR values, GM_ADDR indices, GM_ADDR globalWorkGm, GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
bool isLargest = (tilingData.isLargest > 0) ? true : false;
bool isSort = (tilingData.isSort > 0) ? true : false;
TPipe tPipe;
if (isLargest) {
if (isSort) {
RadixSortTopKMultiCoreOptimization<T, true, true, T_INDEX_TO> radixSortTopKMultiCoreOptimization;
radixSortTopKMultiCoreOptimization.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopKMultiCoreOptimization.Process();
} else {
RadixSortTopKMultiCoreOptimization<T, true, false, T_INDEX_TO> radixSortTopKMultiCoreOptimization;
radixSortTopKMultiCoreOptimization.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopKMultiCoreOptimization.Process();
}
} else {
if (isSort) {
RadixSortTopKMultiCoreOptimization<T, false, true, T_INDEX_TO> radixSortTopKMultiCoreOptimization;
radixSortTopKMultiCoreOptimization.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopKMultiCoreOptimization.Process();
} else {
RadixSortTopKMultiCoreOptimization<T, false, false, T_INDEX_TO> radixSortTopKMultiCoreOptimization;
radixSortTopKMultiCoreOptimization.Init(x, k, values, indices, globalWorkGm, &tilingData);
radixSortTopKMultiCoreOptimization.Process();
}
}
}
template <typename T, typename UNSINGED_TYPE, int32_t NUM_PASS, typename T_INDEX, typename T_INDEX_TO>
__aicore__ inline void RadixSortTopKSingleBlockOpObject(
GM_ADDR x, GM_ADDR k, GM_ADDR values, GM_ADDR indices, GM_ADDR globalWorkGm, GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
bool isLargest = (tilingData.isLargest > 0) ? true : false;
bool isSort = (tilingData.isSort > 0) ? true : false;
TPipe tPipe;
if (isLargest) {
if (isSort) {
RadixSortTopKSingleBlock<T, UNSINGED_TYPE, true, true, T_INDEX, T_INDEX_TO> radixSortTopKSingleBlock;
radixSortTopKSingleBlock.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleBlock.Process();
} else {
RadixSortTopKSingleBlock<T, UNSINGED_TYPE, true, false, T_INDEX, T_INDEX_TO> radixSortTopKSingleBlock;
radixSortTopKSingleBlock.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleBlock.Process();
}
} else {
if (isSort) {
RadixSortTopKSingleBlock<T, UNSINGED_TYPE, false, true, T_INDEX, T_INDEX_TO> radixSortTopKSingleBlock;
radixSortTopKSingleBlock.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleBlock.Process();
} else {
RadixSortTopKSingleBlock<T, UNSINGED_TYPE, false, false, T_INDEX, T_INDEX_TO> radixSortTopKSingleBlock;
radixSortTopKSingleBlock.Init(x, k, values, indices, globalWorkGm, &tilingData, &tPipe);
radixSortTopKSingleBlock.Process();
}
}
}
template <typename T, typename UNSINGED_TYPE, int32_t NUM_PASS, typename T_INDEX_TO>
__aicore__ inline void generateOpObject(
GM_ADDR x,
GM_ADDR k,
GM_ADDR values,
GM_ADDR indices,
GM_ADDR globalWorkGm,
GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
bool isLargest = (tilingData.isLargest > 0) ? true : false;
bool isSort = (tilingData.isSort > 0) ? true : false;
bool isSingleBlock = (tilingData.lastDimNeedCore == 1) ? true : false;
bool isSingleCore = (tilingData.modeType == SINGLE_CORE_MODE) ? true : false;
bool isSortAndTopK = (tilingData.modeType == SORT_AND_TOP_K_MODE) ? true : false;
bool isInInt32Range = (tilingData.isInInt32Range > 0) ? true : false;
bool isMultiCoreOptimMode = (tilingData.modeType == MULT_CORE_OPTIM_MODE) ? true : false;
if (isSingleBlock) {
if (isInInt32Range) {
RadixSortTopKSingleBlockOpObject<T, UNSINGED_TYPE, NUM_PASS, int32_t, T_INDEX_TO>(
x, k, values, indices, globalWorkGm, tiling);
} else {
RadixSortTopKSingleBlockOpObject<T, UNSINGED_TYPE, NUM_PASS, int64_t, T_INDEX_TO>(
x, k, values, indices, globalWorkGm, tiling);
}
return;
}
if (isSortAndTopK) {
if (isInInt32Range) {
SortAndTopKOpObject<T, UNSINGED_TYPE, NUM_PASS, uint32_t, T_INDEX_TO>(
x, values, indices, globalWorkGm, tiling);
} else {
SortAndTopKOpObject<T, UNSINGED_TYPE, NUM_PASS, int64_t, T_INDEX_TO>(
x, values, indices, globalWorkGm, tiling);
}
return;
}
if (isSingleCore) {
if (isInInt32Range) {
RadixSortTopKSingleCoreOpObject<T, UNSINGED_TYPE, NUM_PASS, int32_t, T_INDEX_TO>(
x, k, values, indices, globalWorkGm, tiling);
} else {
RadixSortTopKSingleCoreOpObject<T, UNSINGED_TYPE, NUM_PASS, int64_t, T_INDEX_TO>(
x, k, values, indices, globalWorkGm, tiling);
}
return;
}
if (isMultiCoreOptimMode) {
if (isInInt32Range) {
RadixSortTopKMultiCoreOpObject<T, T_INDEX_TO>(
x, k, values, indices, globalWorkGm, tiling);
}
return;
}
if (isInInt32Range) {
RadixSortTopKOpObject<T, UNSINGED_TYPE, NUM_PASS, int32_t, T_INDEX_TO>(
x, k, values, indices, globalWorkGm, tiling);
} else {
RadixSortTopKOpObject<T, UNSINGED_TYPE, NUM_PASS, int64_t, T_INDEX_TO>(
x, k, values, indices, globalWorkGm, tiling);
}
}
template <typename T, typename CONVERT_TYPE, typename INDEX_DTYPE>
__aicore__ inline void generateMergeTopKObject(
GM_ADDR x, GM_ADDR values, GM_ADDR indices, GM_ADDR globalWorkGm, GM_ADDR tiling)
{
GET_TILING_DATA(tilingData, tiling);
bool isLargest = (tilingData.isLargest > 0) ? true : false;
TPipe pipe;
if (isLargest) {
topkV2::MergeSort<T, CONVERT_TYPE, TopKV2TilingDataSimd, true, INDEX_DTYPE> mergeSort;
mergeSort.Init(x, values, indices, globalWorkGm, &tilingData, &pipe);
mergeSort.ProcessSort();
} else {
topkV2::MergeSort<T, CONVERT_TYPE, TopKV2TilingDataSimd, false, INDEX_DTYPE> mergeSort;
mergeSort.Init(x, values, indices, globalWorkGm, &tilingData, &pipe);
mergeSort.ProcessSort();
}
}
extern "C" __global__ __aicore__ void top_k_v2(GM_ADDR x, GM_ADDR k, GM_ADDR values, GM_ADDR indices, GM_ADDR workspace, GM_ADDR tiling)
{
if (workspace == nullptr) {
return;
}
SetSysWorkspace(workspace);
GM_ADDR globalWorkGm = GetUserWorkspace(workspace);
if (globalWorkGm == nullptr) {
return;
}
#if ORIG_DTYPE_X == DT_INT64
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_INT64);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_INT64
generateOpObject<int64_t, uint64_t, topkV2::B64_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_INT32
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_INT32);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_INT32
generateOpObject<int32_t, uint32_t, topkV2::B32_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_INT16
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_INT16);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_INT16
generateOpObject<int16_t, uint16_t, topkV2::B16_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_INT8
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_INT8);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_INT8
generateOpObject<int8_t, uint8_t, topkV2::B8_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_UINT64
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_UINT64);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_UINT64
generateOpObject<uint64_t, uint64_t, topkV2::B64_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_UINT32
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_UINT32);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_UINT32
generateOpObject<uint32_t, uint32_t, topkV2::B32_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_UINT16
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_UINT16);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_UINT16
generateOpObject<uint16_t, uint16_t, topkV2::B16_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_UINT8
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_UINT8);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_UINT8
generateOpObject<uint8_t, uint8_t, topkV2::B8_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_FLOAT
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_FLOAT);
TILING_KEY_IS(TOPK_MERGE_SORT_TILING_KEY_FLOAT);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_FLOAT
generateOpObject<float, uint32_t, topkV2::B32_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#elif TILING_KEY_VAR == TOPK_MERGE_SORT_TILING_KEY_FLOAT
generateMergeTopKObject<float, float, DTYPE_INDICES>(x, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_FLOAT16
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_FLOAT16);
TILING_KEY_IS(TOPK_MERGE_SORT_TILING_KEY_FLOAT16);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_FLOAT16
generateOpObject<half, uint16_t, topkV2::B16_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#elif TILING_KEY_VAR == TOPK_MERGE_SORT_TILING_KEY_FLOAT16
generateMergeTopKObject<half, half, DTYPE_INDICES>(x, values, indices, globalWorkGm, tiling);
#endif
#endif
#if ORIG_DTYPE_X == DT_BF16
TILING_KEY_IS(TOPK_COMMON_TILING_KEY_BF16);
TILING_KEY_IS(TOPK_MERGE_SORT_TILING_KEY_BF16);
#if TILING_KEY_VAR == TOPK_COMMON_TILING_KEY_BF16
generateOpObject<bfloat16_t, uint16_t, topkV2::B16_BITE_SIZE, DTYPE_INDICES>(x, k, values, indices, globalWorkGm, tiling);
#elif TILING_KEY_VAR == TOPK_MERGE_SORT_TILING_KEY_BF16
generateMergeTopKObject<bfloat16_t, float, DTYPE_INDICES>(x, values, indices, globalWorkGm, tiling);
#endif
#endif
}
#endif