* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "c_interface_utils.h"
#include "atb/utils/config.h"
#include "atb/utils/singleton.h"
#include "atb/utils/log.h"
using namespace atb;
using namespace atb::cinterfaceTest;
// ---------------------------------------------------------------------------
// Sizes used by the MLA preprocess test below.
// Unless noted otherwise each constant is an ELEMENT count: the test multiplies it by the
// element size (sizeofFP16, sizeof(int32_t), sizeof(int64_t)) to obtain a buffer's byte size.
// The int8 tensors (quantOffset*, wdqkv, wuq) use the element count directly as the byte count.
// NOTE(review): several constants (batch, kvCache, kvCacheC2/C3, kvCacheRopeC2/C3,
// outTensor0C1/C2, outTensor1C1/C2/C3, outTensor3C2) are not referenced by the test shown here;
// they look like shapes for other cache modes — confirm before removing.
// ---------------------------------------------------------------------------

// Number of tensor slots (inputs + outputs) handed to the operation.
const int64_t MLAINOUTMLAPP = 28;
const int64_t dims = 7168;
const int64_t dimB = 2112;
const int64_t dimC = 1536;
const int64_t dimD = 1;
// quantScale0 / quantOffset0 are viewed as {dimD} (one element), exactly like quantScale1 /
// quantOffset1, so their element count must be 1. A value of 0 allocated a zero-byte buffer
// for a one-element tensor.
const int64_t quantScale0 = 1;
const int64_t quantOffset0 = 1;
// int8 FRACTAL_NZ weight viewed as {1, 224, dimB, 32}; 224 * 32 must equal dims.
const int64_t wdqkv = 1 * 224 * dimB * 32;
static_assert(224 * 32 == dims, "wdqkv view {1, 224, dimB, 32} must cover dims");
const int64_t deScale0 = 2112;
const int64_t bias0 = 2112;
const int64_t beta1 = 1536;
const int64_t gamma1 = 1536;
const int64_t quantScale1 = 1;
const int64_t quantOffset1 = 1;
const int64_t blockSize = 128;
const int64_t numTokens = 32;
const int64_t numHeads = 32;
const int64_t kvHeads = 1;
const int64_t kSeqlen = 256;
const int64_t batch = numTokens * kSeqlen;
const int64_t numBlocks = 64;
// Byte size of one fp16 element.
const int64_t sizeofFP16 = 2;
// int8 FRACTAL_NZ weight viewed as {1, 48, numHeads * 192, 32}; 48 * 32 must equal dimC.
const int64_t wuq = 1 * 48 * numHeads * 192 * 32;
static_assert(48 * 32 == dimC, "wuq view {1, 48, numHeads * 192, 32} must cover dimC");
const int64_t deScale1 = numHeads * 192;
const int64_t bias1 = numHeads * 192;
const int64_t gamma2 = 512;
const int64_t cosNum = numTokens * 64;
const int64_t sinNum = numTokens * 64;
const int64_t wuk = numHeads * 128 * 512;
const int64_t kvCache = numBlocks * blockSize * 1 * 512;
const int64_t kvCacheC0 = numBlocks * blockSize * 1 * 576;
const int64_t kvCacheC2 = numBlocks * numHeads * 512 / 32 * blockSize * 32;
const int64_t kvCacheC3 = numBlocks * numHeads * 512 / 16 * blockSize * 16;
const int64_t kvCacheRope = numBlocks * blockSize * 1 * 64;
const int64_t kvCacheRopeC2 = numBlocks * numHeads * 64 / 16 * blockSize * 16;
const int64_t kvCacheRopeC3 = numBlocks * numHeads * 64 / 16 * blockSize * 16;
const int64_t slotmapping = numTokens;
const int64_t ctkvScale = 1;
const int64_t qNopeScale = numHeads;
const int64_t outTensor0C0 = numTokens * numHeads * 576;
const int64_t outTensor0C1 = numTokens * numHeads * 512;
const int64_t outTensor0C2 = numTokens * numHeads * 512;
const int64_t outTensor1C0 = numBlocks * blockSize * 576;
const int64_t outTensor1C1 = numBlocks * blockSize * 512;
const int64_t outTensor1C2 = numBlocks * numHeads * 512 / 32 * blockSize * 32;
const int64_t outTensor1C3 = numBlocks * numHeads * 512 / 16 * blockSize * 16;
const int64_t outTensor2 = numTokens * numHeads * 64;
const int64_t outTensor3C1 = numBlocks * blockSize * 1 * 64;
const int64_t outTensor3C2 = numBlocks * numHeads * 64 / 16 * blockSize * 16;
/**
 * Smoke test for the MLA preprocess C interface
 * (AtbMLAPreprocessGetWorkspaceSize + AtbMLAPreprocess).
 *
 * Allocates MLAINOUTMLAPP host/device buffers and wraps each in an aclTensor. Slots 0-23 are
 * passed as the operation's inputs and slots 24-27 (sized from the outTensor* constants) as its
 * outputs. The test then queries the workspace size, launches the operation once, synchronizes
 * the stream and checks only the returned status codes; buffer contents are not inspected.
 * The test is skipped when Config::Is910B() reports false.
 */
TEST(TestATBACL, TestMLAPreProcesscomb0Q0C0)
{
    atb::Context *context = nullptr;
    aclrtStream stream = nullptr;
    int64_t deviceId = 0;
    Init(&context, &stream, &deviceId);
    if (!atb::GetSingleton<atb::Config>().Is910B()) {
        ATB_LOG(ERROR) << "MLA PreProcess only supports A2/A3";
        Destroy(&context, &stream);
        GTEST_SKIP();
    }
    uint8_t *inoutHost[MLAINOUTMLAPP];
    uint8_t *inoutDevice[MLAINOUTMLAPP];
    aclTensor *tensorList[MLAINOUTMLAPP];
    // Byte size of every slot. Each entry must cover the shape and dtype given to
    // CreateACLTensorInOut for the same slot below; otherwise the kernel reads or writes past
    // the end of the device buffer.
    size_t inoutSize[MLAINOUTMLAPP] = {
        numTokens * dims * sizeofFP16,  // [0]  {numTokens, dims} fp16
        dims * sizeofFP16,              // [1]  {dims} fp16
        dims * sizeofFP16,              // [2]  {dims} fp16
        quantScale0 * sizeofFP16,       // [3]  {dimD} fp16
        quantOffset0,                   // [4]  {dimD} int8
        wdqkv,                          // [5]  {1, 224, dimB, 32} int8, FRACTAL_NZ
        deScale0 * sizeof(int64_t),     // [6]  {dimB} int64
        bias0 * sizeof(int32_t),        // [7]  {dimB} int32
        gamma1 * sizeofFP16,            // [8]  {dimC} fp16
        beta1 * sizeofFP16,             // [9]  {dimC} fp16
        quantScale1 * sizeofFP16,       // [10] {dimD} fp16
        quantOffset1,                   // [11] {dimD} int8
        wuq,                            // [12] {1, 48, numHeads * 192, 32} int8, FRACTAL_NZ
        deScale1 * sizeof(int64_t),     // [13] {numHeads * 192} int64 — same element type as deScale0
        bias1 * sizeof(int32_t),        // [14] {numHeads * 192} int32
        gamma2 * sizeofFP16,            // [15] {gamma2} fp16
        cosNum * sizeofFP16,            // [16] {numTokens, 64} fp16
        sinNum * sizeofFP16,            // [17] {numTokens, 64} fp16
        wuk * sizeofFP16,               // [18] {numHeads, 128, 512} fp16
        kvCacheC0 * sizeofFP16,         // [19] {numBlocks, blockSize, 1, 576} fp16
        kvCacheRope * sizeofFP16,       // [20] {numBlocks, blockSize, 1, 64} fp16
        slotmapping * sizeof(int32_t),  // [21] {numTokens} int32
        ctkvScale * sizeofFP16,         // [22] {ctkvScale} fp16
        qNopeScale * sizeofFP16,        // [23] {qNopeScale} fp16
        outTensor0C0 * sizeofFP16,      // [24] {numTokens, numHeads, 576} fp16
        outTensor1C0 * sizeofFP16,      // [25] {numBlocks, blockSize, 1, 576} fp16
        outTensor2 * sizeofFP16,        // [26] {numTokens, numHeads, 64} fp16
        outTensor3C1 * sizeofFP16,      // [27] {numBlocks, blockSize, 1, 64} fp16
    };
    CreateInOutData(MLAINOUTMLAPP, inoutHost, inoutDevice, inoutSize);
    // NOTE(review): the calls below assume CreateACLTensorInOut advances `i` by one slot per
    // call (it is passed both `i` and inoutDevice[i]); confirm in c_interface_utils.h.
    size_t i = 0;
    aclDataType inputFormat = ACL_FLOAT16;
    // [0]
    std::vector<int64_t> viewDim = {numTokens, dims};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [1]
    viewDim = {dims};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [2]
    viewDim = {dims};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [3]
    viewDim = {dimD};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [4]
    viewDim = {dimD};
    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [5]
    viewDim = {1, 224, dimB, 32};
    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_FRACTAL_NZ, tensorList, i, inoutDevice[i]);
    // [6]
    viewDim = {dimB};
    CreateACLTensorInOut(viewDim, ACL_INT64, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [7]
    viewDim = {dimB};
    CreateACLTensorInOut(viewDim, ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [8]
    viewDim = {dimC};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [9]
    viewDim = {dimC};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [10]
    viewDim = {dimD};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [11]
    viewDim = {dimD};
    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [12]
    viewDim = {1, 48, numHeads * 192, 32};
    CreateACLTensorInOut(viewDim, ACL_INT8, ACL_FORMAT_FRACTAL_NZ, tensorList, i, inoutDevice[i]);
    // [13]
    viewDim = {numHeads * 192};
    CreateACLTensorInOut(viewDim, ACL_INT64, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [14]
    viewDim = {numHeads * 192};
    CreateACLTensorInOut(viewDim, ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [15]
    viewDim = {gamma2};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [16]
    viewDim = {numTokens, 64};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [17]
    viewDim = {numTokens, 64};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [18]
    viewDim = {numHeads, 128, 512};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [19]
    viewDim = {numBlocks, blockSize, 1, 576};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [20]
    viewDim = {numBlocks, blockSize, 1, 64};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [21]
    viewDim = {numTokens};
    CreateACLTensorInOut(viewDim, ACL_INT32, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [22]
    viewDim = {ctkvScale};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [23]
    viewDim = {qNopeScale};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [24]
    viewDim = {numTokens, numHeads, 576};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [25]
    viewDim = {numBlocks, blockSize, 1, 576};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [26]
    viewDim = {numTokens, numHeads, 64};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    // [27]
    viewDim = {numBlocks, blockSize, 1, 64};
    CreateACLTensorInOut(viewDim, inputFormat, ACL_FORMAT_ND, tensorList, i, inoutDevice[i]);
    uint64_t workspaceSize = 0;
    atb::Operation *op = nullptr;
    // NOTE(review): the meaning of the scalar arguments (0, 0, 0, 1e-5, 2, 3, true, true, true, 0, 0)
    // is defined by the AtbMLAPreprocessGetWorkspaceSize declaration, which is not visible here.
    atb::Status ret = AtbMLAPreprocessGetWorkspaceSize(
        tensorList[0], tensorList[1], tensorList[2], tensorList[3], tensorList[4], tensorList[5], tensorList[6],
        tensorList[7], tensorList[8], tensorList[9], tensorList[10], tensorList[11], tensorList[12], tensorList[13],
        tensorList[14], tensorList[15], tensorList[16], tensorList[17], tensorList[18], tensorList[19], tensorList[20],
        tensorList[21], tensorList[22], tensorList[23], 0, 0, 0, 1e-5, 2, 3, true, true, true, 0, 0, tensorList[24],
        tensorList[25], tensorList[26], tensorList[27], &workspaceSize, &op, context);
    EXPECT_EQ(ret, atb::NO_ERROR);
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        EXPECT_EQ(ret, ACL_SUCCESS);
    }
    ret = AtbMLAPreprocess(workspaceAddr, workspaceSize, op, context);
    EXPECT_EQ(ret, atb::NO_ERROR);
    // Wait for the launched work to finish before releasing any buffer it uses; a failure here
    // is where asynchronous execution errors surface, so it must not be ignored.
    ret = aclrtSynchronizeStream(stream);
    EXPECT_EQ(ret, ACL_SUCCESS);
    if (workspaceSize > 0) {
        EXPECT_EQ(aclrtFree(workspaceAddr), ACL_SUCCESS);
    }
    EXPECT_EQ(atb::DestroyOperation(op), atb::NO_ERROR);
    // Release the host/device buffers while the runtime objects created by Init are still alive,
    // and only then tear them down with Destroy.
    for (size_t slot = 0; slot < static_cast<size_t>(MLAINOUTMLAPP); ++slot) {
        aclrtFreeHost(inoutHost[slot]);
        aclrtFree(inoutDevice[slot]);
    }
    Destroy(&context, &stream);
}