* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "topk.h"
#include "math/sort_with_index/op_api/sort_with_index.h"
#include "aclnn_kernels/cast.h"
#include "op_api/aclnn_check.h"
#include "opdev/aicpu/aicpu_task.h"
#include "opdev/make_op_executor.h"
#include "opdev/op_def.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/shape_utils.h"
#include "opdev/platform.h"
using namespace op;
namespace l0op {
// Op type names referenced by the ADD_TO_LAUNCHER_LIST_* calls in this file:
// TopKV2 and TopKV3 are queued through the AiCore launcher macro, TopK through
// the AiCPU launcher macro.
OP_TYPE_REGISTER(TopKV2);
OP_TYPE_REGISTER(TopKV3);
OP_TYPE_REGISTER(TopK);
// Thresholds used by IsAiCoreSupport on ASCEND910: the input is rejected when its
// element count exceeds MAX_AICORE_CALC_INPUTSIZE while k is below MAX_AICORE_CALC_DIM.
constexpr int64_t MAX_AICORE_CALC_INPUTSIZE = 32768;
constexpr int64_t MAX_AICORE_CALC_DIM = 8;

// Thresholds used by IsAiCoreSupport for DT_INT64 inputs on reg-base platforms:
// the input is rejected when its element count lies strictly between the MIN and
// MAX input sizes and its last dimension is below the DIM threshold.
constexpr int64_t MAX_AICORE_CALC_REG_BASE_INT64_DIM = 22;
constexpr int64_t MAX_AICORE_CALC_REG_BASE_INT64_INPUTSIZE = 25000000;
constexpr int64_t MIN_AICORE_CALC_REG_BASE_INT64_INPUTSIZE = 65000;

// Exclusive upper bound on k checked by IsAscendCSupport.
constexpr int64_t MAX_K = 16;

// Thresholds used by IsSortWithIndex: k must be greater than TWO_THOUSAND and the
// last dimension must not exceed SORT_AND_TOP_K_THRESHOLD.
constexpr int64_t TWO_THOUSAND = 2000;
constexpr int64_t SORT_AND_TOP_K_THRESHOLD = 10000000;
// Data types accepted by IsAiCoreSupport when the SoC version is ASCEND310P.
static const std::initializer_list<op::DataType> ANCIENT_DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT16, op::DataType::DT_FLOAT};
// Data types accepted by IsAiCoreSupport when the SoC version is ASCEND910B or ASCEND910_93.
static const std::initializer_list<op::DataType> CURRENT_DTYPE_SUPPORT_LIST = {
op::DataType::DT_FLOAT16, op::DataType::DT_FLOAT, op::DataType::DT_BF16};
// Data types accepted by IsAiCoreSupport on platforms for which IsRegBase() is true
// (signed/unsigned integers of 8 to 64 bits plus BF16, FLOAT16 and FLOAT).
static const std::initializer_list<op::DataType> FUTURE_DTYPE_SUPPORT_LIST = {
op::DataType::DT_INT64, op::DataType::DT_INT32, op::DataType::DT_INT16, op::DataType::DT_INT8,
op::DataType::DT_UINT64, op::DataType::DT_UINT32, op::DataType::DT_UINT16, op::DataType::DT_UINT8,
op::DataType::DT_BF16, op::DataType::DT_FLOAT16, op::DataType::DT_FLOAT};
// Tells whether `version` belongs to one of the SoC groups for which IsAiCoreSupport
// accepts FLOAT16, FLOAT and BF16 in its fallback path.
// The range checks rely on the declaration order of the SocVersion enumerators.
static bool IsFloatTypeSoc(SocVersion version)
{
    // ASCEND910B .. ASCEND910E (inclusive).
    if (version >= SocVersion::ASCEND910B && version <= SocVersion::ASCEND910E) {
        return true;
    }
    // ASCEND310P .. ASCEND310C (inclusive).
    if (version >= SocVersion::ASCEND310P && version <= SocVersion::ASCEND310C) {
        return true;
    }
    return version == SocVersion::ASCEND610LITE;
}
static bool IsAiCoreSupport(const aclTensor* self, int64_t k)
{
SocVersion version = GetCurrentPlatformInfo().GetSocVersion();
if (version == SocVersion::ASCEND910) {
auto inputShape = self->GetViewShape();
int64_t tmpDim = static_cast<int64_t>(inputShape.GetDimNum());
int64_t inputSize = 1;
for (int64_t i = 0; i < tmpDim; i++) {
inputSize *= inputShape.GetDim(i);
}
if (inputSize > MAX_AICORE_CALC_INPUTSIZE && k < MAX_AICORE_CALC_DIM) {
return false;
}
}
if (version == SocVersion::ASCEND310P) {
OP_LOGW("l0op::TopK use ANCIENT_DTYPE_SUPPORT_LIST for socVerison[%d]", static_cast<int32_t>(version));
return CheckType(self->GetDataType(), ANCIENT_DTYPE_SUPPORT_LIST);
} else if (version == SocVersion::ASCEND910B || version == SocVersion::ASCEND910_93) {
OP_LOGW("l0op::TopK use CURRENT_DTYPE_SUPPORT_LIST for socVerison[%d]", static_cast<int32_t>(version));
return CheckType(self->GetDataType(), CURRENT_DTYPE_SUPPORT_LIST);
} else if (IsRegBase()) {
if (self->GetDataType() == op::DataType::DT_INT64) {
auto inputShape = self->GetViewShape();
int64_t tmpDim = static_cast<int64_t>(inputShape.GetDimNum());
int64_t inputSize = 1;
for (int64_t i = 0; i < tmpDim; i++) {
inputSize *= inputShape.GetDim(i);
}
if (inputSize < MAX_AICORE_CALC_REG_BASE_INT64_INPUTSIZE &&
inputSize > MIN_AICORE_CALC_REG_BASE_INT64_INPUTSIZE &&
inputShape.GetDim(tmpDim - 1) < MAX_AICORE_CALC_REG_BASE_INT64_DIM) {
return false;
}
}
OP_LOGW("l0op::TopK use FUTURE_DTYPE_SUPPORT_LIST for socVerison[%d]", static_cast<int32_t>(version));
return CheckType(self->GetDataType(), FUTURE_DTYPE_SUPPORT_LIST);
} else {
}
if (IsFloatTypeSoc(version)) {
return CheckType(
self->GetDataType(), {op::DataType::DT_FLOAT16, op::DataType::DT_FLOAT, op::DataType::DT_BF16});
}
return self->GetDataType() == op::DataType::DT_FLOAT16;
}
static bool IsAscendCSupport(const aclTensor* self, int64_t k)
{
SocVersion version = GetCurrentPlatformInfo().GetSocVersion();
if (version == SocVersion::ASCEND310P && CheckType(self->GetDataType(), {op::DataType::DT_FLOAT16}) && k < MAX_K) {
return true;
}
return false;
}
// Queues the TopKV2 op through the AiCore launcher macro.
//
// Inputs are (self, k), outputs are (values, indices) and the attributes are passed
// in the order (sorted, dim, largest).
// Returns the caller-provided (values, indices) pair.
// NOTE(review): `executor` is not referenced by name in the body; presumably it is
// consumed by the launcher macro - confirm.
// NOTE(review): the outcome of ADD_TO_LAUNCHER_LIST_AICORE is not inspected here,
// unlike the AiCPU path in TopkAiCpu - confirm whether the macro can report failure.
std::tuple<aclTensor*, aclTensor*> TopkV2AiCore(
const aclTensor* self, const aclTensor* k, int64_t dim, bool largest, bool sorted, aclTensor* values,
aclTensor* indices, aclOpExecutor* executor)
{
L0_DFX(TopkV2AiCore, self, k, dim, largest, sorted, values, indices);
ADD_TO_LAUNCHER_LIST_AICORE(TopKV2, OP_INPUT(self, k), OP_OUTPUT(values, indices), OP_ATTR(sorted, dim, largest));
return std::tuple<aclTensor*, aclTensor*>(values, indices);
}
// Queues the TopKV2 op through the AiCore launcher macro, additionally passing the
// requested indices data type as a fourth attribute.
//
// Inputs are (self, k), outputs are (values, indices) and the attributes are passed
// in the order (sorted, dim, largest, indicesDType).
// Topk calls this variant when IsRegBase() is true.
// Returns the caller-provided (values, indices) pair.
// NOTE(review): `executor` is not referenced by name in the body; presumably it is
// consumed by the launcher macro - confirm.
std::tuple<aclTensor*, aclTensor*> TopkV2AiCoreForDavid(
const aclTensor* self, const aclTensor* k, int64_t dim, bool largest, bool sorted, aclTensor* values,
aclTensor* indices, op::DataType indicesDType, aclOpExecutor* executor)
{
L0_DFX(TopkV2AiCoreForDavid, self, k, dim, largest, sorted, values, indices, indicesDType);
ADD_TO_LAUNCHER_LIST_AICORE(
TopKV2, OP_INPUT(self, k), OP_OUTPUT(values, indices), OP_ATTR(sorted, dim, largest, indicesDType));
return std::tuple<aclTensor*, aclTensor*>(values, indices);
}
// Queues the TopKV3 op through the AiCore launcher macro.
//
// Inputs are (self, k), outputs are (values, indices) and the attributes are passed
// in the order (sorted, dim, largest).
// Topk calls this variant when IsAscendCSupport() holds.
// Returns the caller-provided (values, indices) pair.
// NOTE(review): `executor` is not referenced by name in the body; presumably it is
// consumed by the launcher macro - confirm.
std::tuple<aclTensor*, aclTensor*> TopkV3(
const aclTensor* self, const aclTensor* k, int64_t dim, bool largest, bool sorted, aclTensor* values,
aclTensor* indices, aclOpExecutor* executor)
{
L0_DFX(TopkV3, self, k, dim, largest, sorted, values, indices);
ADD_TO_LAUNCHER_LIST_AICORE(TopKV3, OP_INPUT(self, k), OP_OUTPUT(values, indices), OP_ATTR(sorted, dim, largest));
return std::tuple<aclTensor*, aclTensor*>(values, indices);
}
// Tells whether the (self, k, sorted) combination meets the sort-with-index criteria:
// k is greater than TWO_THOUSAND, the last view dimension does not exceed
// SORT_AND_TOP_K_THRESHOLD, and a sorted result is requested.
static bool IsSortWithIndex(const aclTensor* self, int64_t k, bool sorted)
{
    auto viewShape = self->GetViewShape();
    const int64_t rank = static_cast<int64_t>(viewShape.GetDimNum());

    // The checks keep the evaluation order of the original conjunction.
    if (k <= TWO_THOUSAND) {
        return false;
    }
    if (viewShape.GetDim(rank - 1) > SORT_AND_TOP_K_THRESHOLD) {
        return false;
    }
    return sorted;
}
// Queues the TopK op through the AiCPU launcher macro.
//
// Inputs are (self, k), outputs are (values, indices). The attribute names
// {"sorted", "largest", "dim"} are paired with the values (sorted, largest, dim);
// note that this order differs from the AiCore wrappers in this file.
// Returns (values, indices) when the launcher macro yields ACLNN_SUCCESS and
// (nullptr, nullptr) otherwise.
std::tuple<aclTensor*, aclTensor*> TopkAiCpu(
const aclTensor* self, const aclTensor* k, int64_t dim, bool largest, bool sorted, aclTensor* values,
aclTensor* indices, aclOpExecutor* executor)
{
L0_DFX(TopkAiCpu, self, k, dim, largest, sorted, values, indices);
// NOTE(review): `space` is not referenced by name below; presumably
// ADD_TO_LAUNCHER_LIST_AICPU picks it up from the enclosing scope - confirm.
static internal::AicpuTaskSpace space("TopK");
auto ret = ADD_TO_LAUNCHER_LIST_AICPU(
TopK, OP_ATTR_NAMES({"sorted", "largest", "dim"}), OP_INPUT(self, k), OP_OUTPUT(values, indices),
OP_ATTR(sorted, largest, dim));
if (ret != ACLNN_SUCCESS) {
return std::tuple<aclTensor*, aclTensor*>(nullptr, nullptr);
}
return std::tuple<aclTensor*, aclTensor*>(values, indices);
}
std::tuple<aclTensor*, aclTensor*> Topk(
const aclTensor* self, int64_t k, int64_t dim, bool largest, bool sorted, op::DataType indicesDType,
aclOpExecutor* executor)
{
op::Shape outShape = self->GetStorageShape();
outShape.SetDim(dim, k);
const aclScalar* kScalar = executor->AllocScalar(k);
if (kScalar == nullptr) {
return std::tuple<aclTensor*, aclTensor*>(nullptr, nullptr);
}
const aclTensor* kTensor = executor->ConvertToTensor(kScalar, op::ToOpDataType(ACL_INT32));
auto valuesOut = executor->AllocTensor(outShape, self->GetDataType(), self->GetStorageFormat());
aclTensor* indicesOut = nullptr;
if (IsRegBase() && !IsSortWithIndex(self, k, sorted)) {
indicesOut = executor->AllocTensor(outShape, indicesDType, self->GetStorageFormat());
} else {
indicesOut = executor->AllocTensor(outShape, op::DataType::DT_INT32, self->GetStorageFormat());
}
if (IsAiCoreSupport(self, k)) {
if (IsAscendCSupport(self, k)) {
return TopkV3(self, kTensor, dim, largest, sorted, valuesOut, indicesOut, executor);
} else {
if (IsRegBase()) {
return TopkV2AiCoreForDavid(
self, kTensor, dim, largest, sorted, valuesOut, indicesOut, indicesDType, executor);
} else {
return TopkV2AiCore(self, kTensor, dim, largest, sorted, valuesOut, indicesOut, executor);
}
}
} else {
aclTensor* indicesInt32 = executor->AllocTensor(outShape, op::DataType::DT_INT32, self->GetStorageFormat());
auto result = TopkAiCpu(self, kTensor, dim, largest, sorted, valuesOut, indicesInt32, executor);
if (std::get<0>(result) == nullptr || std::get<1>(result) == nullptr) {
return result;
}
if (IsRegBase() && indicesDType != op::DataType::DT_INT32) {
auto indicesCast = Cast(indicesInt32, indicesDType, executor);
return std::tuple<aclTensor*, aclTensor*>(valuesOut, const_cast<aclTensor*>(indicesCast));
}
return std::tuple<aclTensor*, aclTensor*>(valuesOut, indicesInt32);
}
}
}