* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "c_interface_utils.h"
#include "atb/utils/config.h"
#include "atb/utils/singleton.h"
#include "atb/utils/log.h"
using namespace atb;
using namespace atb::cinterfaceTest;
// Shape and size parameters used by the MLA prefill C-interface test in this file.
constexpr int64_t numTokens = 32;        // leading dimension of the token-major tensors
constexpr int64_t kvHeads = 1;           // head count used in every head dimension of the test tensors
constexpr int64_t embeddimV = 128;       // last dimension of the non-rope fp16 tensors
constexpr int64_t maxSeqLen = 256;       // per-batch sequence dimension of the batched fp16 tensors
constexpr int64_t sizeofFP16 = 2;        // bytes per ACL_FLOAT16 element
constexpr int64_t MLAPPREFILLINOUT = 9;  // number of tensors the test builds and passes to the operator
/**
 * Smoke test for the C-interface MLA prefill operator.
 *
 * Builds MLAPPREFILLINOUT tensors:
 *   - five ACL_FLOAT16 tensors backed by device buffers (indices 0-4);
 *   - two ACL_INT32 sequence-length tensors backed by a host vector (indices 5 and 6);
 *   - a 512x512 ACL_FLOAT16 tensor (index 7);
 *   - a final ACL_FLOAT16 tensor passed after the scalar arguments (index 8).
 * It then queries the workspace size via AtbMLAPreFillGetWorkspaceSize, runs AtbMLAPreFill and
 * synchronizes the stream. The test is skipped when Config::Is910B() reports false.
 */
TEST(TestATBACL, TestMLAPreFillM0C2C1)
{
    atb::Context *context = nullptr;
    aclrtStream stream = nullptr;
    int64_t deviceId = 0;
    // const so every size expression below is a constant expression (no int64_t -> size_t narrowing
    // inside the brace-initialized size_t array).
    const int64_t batch = 4;
    // Last dimension of tensors 1 and 3.
    const int64_t ropeDim = 64;
    Init(&context, &stream, &deviceId);
    if (!atb::GetSingleton<atb::Config>().Is910B()) {
        ATB_LOG(ERROR) << "MLA prefill only supports A2/A3";
        Destroy(&context, &stream);
        GTEST_SKIP();
    }
    uint8_t *inoutHost[MLAPPREFILLINOUT];
    uint8_t *inoutDevice[MLAPPREFILLINOUT];
    aclTensor *tensorList[MLAPPREFILLINOUT];
    // Byte size of each tensor's buffer; entries match viewDim below element-for-element.
    size_t inoutSize[MLAPPREFILLINOUT] = {
        numTokens * kvHeads * embeddimV * sizeofFP16,
        numTokens * kvHeads * ropeDim * sizeofFP16,
        batch * maxSeqLen * kvHeads * embeddimV * sizeofFP16,
        batch * maxSeqLen * kvHeads * ropeDim * sizeofFP16,
        batch * maxSeqLen * kvHeads * embeddimV * sizeofFP16,
        batch * sizeof(int32_t),
        batch * sizeof(int32_t),
        512 * 512 * sizeofFP16,
        numTokens * kvHeads * embeddimV * sizeofFP16,
    };
    CreateInOutData(MLAPPREFILLINOUT, inoutHost, inoutDevice, inoutSize);

    // Host data for the two ACL_INT32 tensors (indices 5 and 6), one entry per batch.
    // Declared at function scope because the tensors are created from seqlen.data() and may keep referring
    // to it until the operator has run; a loop-local vector would be destroyed before that.
    // NOTE(review): sum(seqlen) = batch * 4 = 16, while numTokens = 32 -- confirm this mismatch is intended.
    std::vector<int32_t> seqlen(static_cast<size_t>(batch), 4);

    size_t i = 0;
    aclDataType inputFormat = ACL_FLOAT16;
    std::vector<std::vector<int64_t>> viewDim = {
        {numTokens, kvHeads, embeddimV},
        {numTokens, kvHeads, ropeDim},
        {batch, maxSeqLen, kvHeads * embeddimV},
        {batch, maxSeqLen, kvHeads * ropeDim},
        {batch, maxSeqLen, kvHeads * embeddimV},
        {batch},
        {batch},
        {512, 512},
        {numTokens, kvHeads, embeddimV},
    };
    // There is no explicit increment here: i is passed to CreateACLTensorInOut, which presumably takes it by
    // reference, fills tensorList[i] and advances it -- TODO confirm against c_interface_utils.
    while (i < viewDim.size()) {
        if (i == 5 || i == 6) {
            CreateACLTensorInOut(viewDim[i], ACL_INT32, ACL_FORMAT_ND, tensorList, i, seqlen.data());
        } else {
            CreateACLTensorInOut(viewDim[i], inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
        }
    }

    uint64_t workspaceSize = 0;
    atb::Operation *op = nullptr;
    // NOTE(review): the meaning of the scalar arguments (kvHeads, 1.0, kvHeads, 2, 1) is defined by the
    // AtbMLAPreFillGetWorkspaceSize declaration, which is not visible here.
    Status ret = AtbMLAPreFillGetWorkspaceSize(tensorList[0], tensorList[1], tensorList[2], tensorList[3],
                                               tensorList[4], tensorList[5], tensorList[6], tensorList[7], kvHeads, 1.0,
                                               kvHeads, 2, 1, tensorList[8], &workspaceSize, &op, context);
    EXPECT_EQ(ret, atb::NO_ERROR);
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        EXPECT_EQ(ret, ACL_SUCCESS);
    }
    ret = AtbMLAPreFill(workspaceAddr, workspaceSize, op, context);
    EXPECT_EQ(ret, atb::NO_ERROR);
    // Kernel failures surface at synchronization, so the result must be checked.
    ret = aclrtSynchronizeStream(stream);
    EXPECT_EQ(ret, ACL_SUCCESS);

    // Guard on the pointer rather than the size so a failed allocation does not trigger a spurious failure here.
    if (workspaceAddr != nullptr) {
        EXPECT_EQ(aclrtFree(workspaceAddr), ACL_SUCCESS);
    }
    EXPECT_EQ(atb::DestroyOperation(op), NO_ERROR);
    // Release the per-tensor buffers before Destroy(), which presumably tears down the stream/context set up
    // by Init(); freeing device memory after that is not safe.
    // NOTE(review): the aclTensor objects in tensorList are not released here -- confirm whether that is needed.
    for (size_t idx = 0; idx < static_cast<size_t>(MLAPPREFILLINOUT); ++idx) {
        aclrtFreeHost(inoutHost[idx]);
        aclrtFree(inoutDevice[idx]);
    }
    Destroy(&context, &stream);
}