* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "c_interface_utils.h"
#include "atb/utils/config.h"
#include "atb/utils/singleton.h"
#include "atb/utils/log.h"
using namespace atb;
using namespace atb::cinterfaceTest;
// Number of tensor slots (inputs plus outputs) handled by TestFusedAddTopK; sizes its local arrays.
constexpr int64_t INOUT_TENSOR_NUM = 6;
// activationType value passed by the test cases in this file.
// NOTE(review): presumably the operator's enum value for sigmoid — confirm against the operator definition.
constexpr int ACTIVATION_SIGMOID = 8;
// Tensor-list slots that TestFusedAddTopK leaves null when enableExpertMapping is false.
constexpr int MAPPING_NUM_INDEX = 2;
constexpr int MAPPING_TABLE_INDEX = 3;
void TestFusedAddTopK(const int64_t batchSize, const int64_t expertNum, const int maxRedundantExpertNum,
const uint32_t groupNum, const uint32_t groupTopk, const uint32_t n, const uint32_t k,
const int activationType, const bool isNorm, const float scale, const bool enableExpertMapping,
const aclDataType dtype)
{
atb::Context *context = nullptr;
aclrtStream stream = nullptr;
int64_t deviceId = 0;
Init(&context, &stream, &deviceId);
if (!atb::GetSingleton<atb::Config>().Is910B()) {
ATB_LOG(ERROR) << "FusedAddTopK only supports A2/A3";
Destroy(&context, &stream);
GTEST_SKIP();
}
uint8_t *inoutHost[INOUT_TENSOR_NUM];
uint8_t *inoutDevice[INOUT_TENSOR_NUM];
aclTensor *tensorList[INOUT_TENSOR_NUM];
std::vector<aclDataType> inputTypes = {dtype, dtype, ACL_INT32, ACL_INT32, ACL_FLOAT, ACL_INT32};
std::vector<std::vector<int64_t>> tensorDim = {
{batchSize, expertNum},
{expertNum},
{expertNum},
{expertNum, maxRedundantExpertNum},
{batchSize, k},
{batchSize, k},
};
size_t inoutSize[INOUT_TENSOR_NUM];
int total = 0;
for (int i = 0; i < INOUT_TENSOR_NUM; ++i) {
if (tensorDim[i].size() == 0) {
inoutSize[i] = 0;
continue;
}
total = 1;
for (int j = 0; j < tensorDim[i].size(); ++j) {
total *= tensorDim[i][j];
}
inoutSize[i] = total * aclDataTypeSize(inputTypes[i]);
}
CreateInOutData(INOUT_TENSOR_NUM, inoutHost, inoutDevice, inoutSize);
size_t i = 0;
while (i < tensorDim.size()) {
if (!enableExpertMapping && (i == MAPPING_NUM_INDEX || i == MAPPING_TABLE_INDEX)) {
tensorList[i] = nullptr;
++i;
continue;
}
CreateACLTensorInOut(tensorDim[i], inputTypes[i], ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
}
uint64_t workspaceSize = 0;
atb::Operation *op = nullptr;
Status ret = AtbFusedAddTopkDivGetWorkspaceSize(
tensorList[0], tensorList[1], tensorList[2], tensorList[3], groupNum, groupTopk, n, k, activationType, isNorm,
scale, enableExpertMapping, tensorList[4], tensorList[5], &workspaceSize, &op, context);
EXPECT_EQ(ret, atb::NO_ERROR);
void *workspaceAddr = nullptr;
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
EXPECT_EQ(ret, ACL_SUCCESS);
}
ret = AtbFusedAddTopkDiv(workspaceAddr, workspaceSize, op, context);
EXPECT_EQ(ret, atb::NO_ERROR);
ret = aclrtSynchronizeStream(stream);
EXPECT_EQ(ret, ACL_SUCCESS);
if (workspaceSize > 0) {
EXPECT_EQ(aclrtFree(workspaceAddr), ACL_SUCCESS);
}
EXPECT_EQ(atb::DestroyOperation(op), NO_ERROR);
Destroy(&context, &stream);
for (i = 0; i < INOUT_TENSOR_NUM; i++) {
aclrtFreeHost(inoutHost[i]);
aclrtFree(inoutDevice[i]);
}
}
// Runs the FusedAddTopK driver with expert mapping enabled on FP16 inputs.
TEST(TestATBACL, TestFusedAddTopK1WithMapping)
{
    const int64_t batchSize = 10;
    const int64_t expertNum = 200;
    const int maxRedundantExpertNum = 32;
    const uint32_t groupNum = 8;
    const uint32_t groupTopk = 2;
    const uint32_t n = 10;
    const uint32_t k = 5;
    const bool isNorm = true;
    const float scale = 0.0887;
    const bool enableExpertMapping = true;
    TestFusedAddTopK(batchSize, expertNum, maxRedundantExpertNum, groupNum, groupTopk, n, k, ACTIVATION_SIGMOID,
                     isNorm, scale, enableExpertMapping, ACL_FLOAT16);
}
// Runs the FusedAddTopK driver with expert mapping disabled on FP16 inputs.
TEST(TestATBACL, TestFusedAddTopK1NoMapping)
{
    const int64_t batchSize = 6;
    const int64_t expertNum = 256;
    const int maxRedundantExpertNum = 8;
    const uint32_t groupNum = 8;
    const uint32_t groupTopk = 2;
    const uint32_t n = 5;
    const uint32_t k = 5;
    const bool isNorm = true;
    const float scale = 0.0887;
    const bool enableExpertMapping = false;
    TestFusedAddTopK(batchSize, expertNum, maxRedundantExpertNum, groupNum, groupTopk, n, k, ACTIVATION_SIGMOID,
                     isNorm, scale, enableExpertMapping, ACL_FLOAT16);
}