/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <exception>

#include "../demo_util.h"
// Device index handed to aclrtSetDevice in main().
constexpr int32_t DEVICE_ID = 2;
// Third dimension of the kvCache / kvCacheRope tensor shapes.
constexpr uint32_t blockSize = 128;
// Leading dimension of the kvCache / kvCacheRope tensor shapes.
constexpr uint32_t blockNum = 64;
* @brief 准备atb::VariantPack中的输入tensor
* @param contextPtr context指针
* @param stream stream
* @param inTensors atb::SVector<atb::Tensor> *atb::VariantPack中的输入tensor
* @return atb::Status 错误码
*/
atb::Status PrepareInTensor1(atb::Context *contextPtr, aclrtStream stream, aclDataType dtype, int tokenNum,
atb::SVector<atb::Tensor> *inTensors)
{
atb::Tensor input;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(tokenNum * 7168, 0), dtype,
aclFormat::ACL_FORMAT_ND, {tokenNum, 7168}, input));
atb::Tensor gamma0;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(7168, 0), dtype,
aclFormat::ACL_FORMAT_ND, {7168}, gamma0));
atb::Tensor beta0;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(7168, 0), dtype,
aclFormat::ACL_FORMAT_ND, {7168}, beta0));
atb::Tensor quantScale0;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(1, 0), dtype, aclFormat::ACL_FORMAT_ND,
{1}, quantScale0));
atb::Tensor quantOffset0;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(1, 1), ACL_INT8,
aclFormat::ACL_FORMAT_ND, {1}, quantOffset0));
atb::Tensor wdqkv;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(224 * 2112 * 32, 1), ACL_INT8,
aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, 224, 2112, 32}, wdqkv));
atb::Tensor deScale0;
if (dtype == ACL_BF16) {
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(2112, 1), ACL_FLOAT,
aclFormat::ACL_FORMAT_ND, {2112}, deScale0));
} else {
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int64_t>(2112, 10), ACL_INT64,
aclFormat::ACL_FORMAT_ND, {2112}, deScale0));
}
atb::Tensor bias0;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(2112, 1), ACL_INT32,
aclFormat::ACL_FORMAT_ND, {2112}, bias0));
atb::Tensor gamma1;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(1536, 0), dtype,
aclFormat::ACL_FORMAT_ND, {1536}, gamma1));
atb::Tensor beta1;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(1536, 0), dtype,
aclFormat::ACL_FORMAT_ND, {1536}, beta1));
atb::Tensor quantScale1;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(1, 0), dtype, aclFormat::ACL_FORMAT_ND,
{1}, quantScale1));
atb::Tensor quantOffset1;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(1, 1), ACL_INT8,
aclFormat::ACL_FORMAT_ND, {1}, quantOffset1));
*inTensors = {input, gamma0, beta0, quantScale0, quantOffset0, wdqkv,
deScale0, bias0, gamma1, beta1, quantScale1, quantOffset1};
return atb::ErrorType::NO_ERROR;
}
* @brief 准备atb::VariantPack中的输入tensor
* @param contextPtr context指针
* @param stream stream
* @param inTensors atb::SVector<atb::Tensor> *atb::VariantPack中的输入tensor
* @return atb::Status 错误码
*/
atb::Status PrepareInTensor2(atb::Context *contextPtr, aclrtStream stream, aclDataType dtype, int tokenNum, int headNum,
atb::SVector<atb::Tensor> *inTensors)
{
atb::Tensor wuq;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int8_t>(48 * headNum * 192 * 32, 1), ACL_INT8,
aclFormat::ACL_FORMAT_FRACTAL_NZ, {1, 48, headNum * 192, 32}, wuq));
atb::Tensor deScale1;
if (dtype == ACL_BF16) {
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(headNum * 192, 1), ACL_FLOAT,
aclFormat::ACL_FORMAT_ND, {headNum * 192}, deScale1));
} else {
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int64_t>(headNum * 192, 10), ACL_INT64,
aclFormat::ACL_FORMAT_ND, {headNum * 192}, deScale1));
}
atb::Tensor bias1;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<int32_t>(headNum * 192, 1), ACL_INT32,
aclFormat::ACL_FORMAT_ND, {headNum * 192}, bias1));
atb::Tensor gamma2;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(512, 0), dtype,
aclFormat::ACL_FORMAT_ND, {512}, gamma2));
atb::Tensor cos;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(tokenNum * 64, 0), dtype,
aclFormat::ACL_FORMAT_ND, {tokenNum, 64}, cos));
atb::Tensor sin;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(tokenNum * 64, 0.5), dtype,
aclFormat::ACL_FORMAT_ND, {tokenNum, 64}, sin));
atb::Tensor wuk;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(headNum * 32 * 128 * 16, 0), dtype,
aclFormat::ACL_FORMAT_FRACTAL_NZ, {headNum, 32, 128, 16}, wuk));
atb::Tensor kvCache;
CHECK_STATUS(CreateTensorFromVector(
contextPtr, stream, std::vector<int8_t>(blockNum * headNum * 512 * blockSize, 1), ACL_INT8,
aclFormat::ACL_FORMAT_FRACTAL_NZ, {blockNum, headNum * 512 / 32, blockSize, 32}, kvCache));
atb::Tensor kvCacheRope;
CHECK_STATUS(CreateTensorFromVector(
contextPtr, stream, std::vector<float>(blockNum * headNum * 64 / 16 * blockSize * 16, 0), dtype,
aclFormat::ACL_FORMAT_FRACTAL_NZ, {blockNum, headNum * 64 / 16, blockSize, 16}, kvCacheRope));
auto slotmappingHost = std::vector<int32_t>(1, tokenNum);
for (size_t i = 0; i < slotmappingHost.size(); i++)
slotmappingHost[i] = static_cast<int32_t>(i);
atb::Tensor slotmapping;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, slotmappingHost, ACL_INT32, aclFormat::ACL_FORMAT_ND,
{tokenNum}, slotmapping));
atb::Tensor ctkvScale;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(1, 0), dtype, aclFormat::ACL_FORMAT_ND,
{1}, ctkvScale));
atb::Tensor qNopeScale;
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, std::vector<float>(headNum, 0), dtype,
aclFormat::ACL_FORMAT_ND, {headNum}, qNopeScale));
atb::SVector<atb::Tensor> tempTensors = {wuq, deScale1, bias1, gamma2, cos, sin,
wuk, kvCache, kvCacheRope, slotmapping, ctkvScale, qNopeScale};
for (auto &tensor : tempTensors) {
inTensors->push_back(tensor);
}
return atb::ErrorType::NO_ERROR;
}
* @brief 创建一个MlaPreprocessOperation operation
* @param atb::Operation * 返回一个Operation指针
*/
atb::Status CreateMlaPreprocessOperation(atb::Operation **mlaPreprocessOp)
{
atb::infer::MlaPreprocessParam param;
param.cacheMode = atb::infer::MlaPreprocessParam::CacheMode::INT8_NZCACHE;
return atb::CreateOperation(param, mlaPreprocessOp);
}
* @brief 进行MlaPreprocessOperation的循环调用
* @param context context指针
* @param stream stream
* @param dtype 指定部分输入/输出vector数据类型
* @param tokenNum 词元数
* @param headNum 头数
* @return atb::Status 错误码
*/
atb::Status RunDemo(atb::Context *context, void *stream, aclDataType dtype, int tokenNum, int headNum)
{
atb::Operation *mlaPreprocessOp = nullptr;
CHECK_STATUS(CreateMlaPreprocessOperation(&mlaPreprocessOp));
atb::VariantPack variantPack;
CHECK_STATUS(PrepareInTensor1(context, stream, dtype, tokenNum, &variantPack.inTensors));
CHECK_STATUS(PrepareInTensor2(context, stream, dtype, tokenNum, headNum, &variantPack.inTensors));
atb::Tensor qOut0;
CreateTensor(ACL_INT8, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, 512}, qOut0);
atb::Tensor &kvCacheOut0 = variantPack.inTensors.at(19);
atb::Tensor qOut1;
CreateTensor(dtype, aclFormat::ACL_FORMAT_ND, {tokenNum, headNum, 64}, qOut1);
atb::Tensor &kvCacheOut1 = variantPack.inTensors.at(20);
variantPack.outTensors = {qOut0, kvCacheOut0, qOut1, kvCacheOut1};
uint64_t workspaceSize = 0;
CHECK_STATUS(mlaPreprocessOp->Setup(variantPack, workspaceSize, context));
uint8_t *workspacePtr = nullptr;
if (workspaceSize > 0) {
CHECK_STATUS(aclrtMalloc((void **)(&workspacePtr), workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
}
for (size_t i = 0; i < 10; i++) {
std::cout << "tokenNum: " << tokenNum << " headNum: " << headNum << " loop: " << i << std::endl;
CHECK_STATUS(mlaPreprocessOp->Execute(variantPack, workspacePtr, workspaceSize, context));
CHECK_STATUS(aclrtSynchronizeStream(stream));
}
for (atb::Tensor &inTensor : variantPack.inTensors) {
CHECK_STATUS(aclrtFree(inTensor.deviceData));
for (atb::Tensor &outTensor : variantPack.outTensors) {
if (outTensor.deviceData == inTensor.deviceData) {
outTensor.deviceData = nullptr;
}
}
inTensor.deviceData = nullptr;
}
for (atb::Tensor &outTensor : variantPack.outTensors) {
if (outTensor.deviceData == nullptr)
continue;
CHECK_STATUS(aclrtFree(outTensor.deviceData));
}
if (workspaceSize > 0) {
CHECK_STATUS(aclrtFree(workspacePtr));
}
return atb::DestroyOperation(mlaPreprocessOp);
}
int main(int argc, char **argv)
{
std::string dtypeStr;
int tokenNum = 4;
int headNum = 128;
aclDataType dtype = ACL_FLOAT16;
if (argc == 4) {
dtypeStr = argv[1];
tokenNum = std::stoi(argv[2]);
headNum = std::stoi(argv[3]);
}
if (dtypeStr == "bf16") {
dtype = ACL_BF16;
}
atb::Context *context = nullptr;
void *stream = nullptr;
CHECK_STATUS(aclInit(nullptr));
CHECK_STATUS(aclrtSetDevice(DEVICE_ID));
CHECK_STATUS(atb::CreateContext(&context));
CHECK_STATUS(aclrtCreateStream(&stream));
context->SetExecuteStream(stream);
RunDemo(context, stream, dtype, tokenNum, headNum);
CHECK_STATUS(aclrtDestroyStream(stream));
CHECK_STATUS(DestroyContext(context));
CHECK_STATUS(aclFinalize());
std::cout << "MlaPreprocess demo success!" << std::endl;
return 0;
}