/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../demo_util.h"
#include <cmath>
#include <numeric>
#include <random>
// Number of batch entries (sequences) in the demo.
const uint32_t BATCH_SIZE = 3;
// Host-side sequence lengths, 2 * BATCH_SIZE entries: the first BATCH_SIZE entries are summed into Q_NTOKENS,
// the next BATCH_SIZE entries into kV_NTOKENS. Kept as a mutable vector because its data() pointer is handed
// to the seqLen input tensor as hostData.
std::vector<int32_t> seqLenHost = {100, 1000, 128, 700, 0, 128};
// Total number of query tokens across the batch.
const uint32_t Q_NTOKENS = std::accumulate(seqLenHost.begin(), seqLenHost.begin() + BATCH_SIZE, 0);
// Total number of key/value tokens across the batch. The end is bounded explicitly so that any extra trailing
// entries in seqLenHost are not counted.
const uint32_t kV_NTOKENS =
    std::accumulate(seqLenHost.begin() + BATCH_SIZE, seqLenHost.begin() + 2 * BATCH_SIZE, 0);
const uint32_t HEAD_NUM = 16;         // number of query heads
const uint32_t KV_HEAD_NUM = 8;       // number of key/value heads
const uint32_t NOPE_HEAD_SIZE = 128;  // per-head size of the "nope" part (presumably the non-RoPE part)
const uint32_t ROPE_HEAD_SIZE = 64;   // per-head size of the "rope" part
const uint32_t MASK_SEQ_LEN = 512;    // the mask tensor is MASK_SEQ_LEN x MASK_SEQ_LEN
const int32_t MASK_INDEX = 5;         // position of the mask tensor among the input tensors
const int32_t SEQLEN_INDEX = 6;       // position of the seqLen tensor (also the last input index)
/**
 * @brief Fill every element of inData with a pseudo-random value drawn uniformly from [low, high).
 * @param inData vector to fill in place; its size is not changed
 * @param low lower bound of the range (inclusive), default -5
 * @param high upper bound of the range (exclusive), default 5; must satisfy low <= high
 * @note Samples are drawn in double precision and narrowed to float, so a sample very close to high may
 *       round to exactly high.
 * @note A fresh std::random_device-seeded std::mt19937 is created on every call, so values differ between runs.
 */
void AssignRandomValue(std::vector<float> &inData, int low = -5, int high = 5)
{
    std::random_device rd;
    std::mt19937 gen(rd());
    // A non-degenerate default range is required: with low == high every element would be the same value.
    std::uniform_real_distribution<> dis(low, high);
    for (float &value : inData) {
        value = static_cast<float>(dis(gen));
    }
}
* @brief 准备atb::VariantPack中的所有输入tensor
* @param contextPtr context指针
* @param stream stream
* @param seqLenHost host侧tensor。序列长度向量,等于1时,为增量或全量;大于1时,为全量
* @param inTensors atb::VariantPack中的输入tensor
* @return atb::Status atb错误码
* @note 需要传入所有host侧tensor
*/
atb::Status PrepareFirstRingInTensor(atb::Context *contextPtr, aclrtStream stream, std::vector<int32_t> &seqLenHost,
atb::SVector<atb::Tensor> &inTensors)
{
atb::Tensor tensorQNope;
atb::Tensor tensorQRope;
atb::Tensor tensorKNope;
atb::Tensor tensorKRope;
atb::Tensor tensorV;
atb::Tensor tensorMask;
std::vector<float> maskData(MASK_SEQ_LEN * MASK_SEQ_LEN, 0);
for (int i = 0; i < MASK_SEQ_LEN; ++i) {
for (int j = i + 1; j < MASK_SEQ_LEN; ++j) {
maskData[i * MASK_SEQ_LEN + j] = 1;
}
}
atb::Tensor tensorSeqLen;
std::vector<std::vector<int64_t>> tensorDim = {
{Q_NTOKENS, HEAD_NUM, NOPE_HEAD_SIZE},
{Q_NTOKENS, HEAD_NUM, ROPE_HEAD_SIZE},
{kV_NTOKENS, KV_HEAD_NUM, NOPE_HEAD_SIZE},
{kV_NTOKENS, KV_HEAD_NUM, ROPE_HEAD_SIZE},
{kV_NTOKENS, KV_HEAD_NUM, NOPE_HEAD_SIZE},
{MASK_SEQ_LEN, MASK_SEQ_LEN},
{2, BATCH_SIZE},
};
std::vector<float> randomData = {};
int64_t totalSize = 0;
inTensors = {tensorQNope, tensorQRope, tensorKNope, tensorKRope, tensorV, tensorMask, tensorSeqLen};
for (int32_t i = 0; i <= SEQLEN_INDEX; ++i) {
if (i == MASK_INDEX) {
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, maskData, ACL_BF16, aclFormat::ACL_FORMAT_ND,
{MASK_SEQ_LEN, MASK_SEQ_LEN}, inTensors[i]));
} else if (i == SEQLEN_INDEX) {
CHECK_STATUS(CreateTensor(ACL_INT32, aclFormat::ACL_FORMAT_ND, {2, BATCH_SIZE}, inTensors[i]));
inTensors[i].hostData = seqLenHost.data();
} else {
totalSize = 1;
for (int j = 0; j < tensorDim[i].size(); ++j) {
totalSize *= tensorDim[i][j];
}
randomData.reserve(totalSize);
randomData.resize(totalSize);
AssignRandomValue(randomData);
CHECK_STATUS(CreateTensorFromVector(contextPtr, stream, randomData, ACL_BF16, aclFormat::ACL_FORMAT_ND,
tensorDim[i], inTensors[i]));
}
}
return atb::ErrorType::NO_ERROR;
}
/**
 * @brief Set up and execute one operation, with a workspace sized by this call's own Setup.
 * @param operation operation to run
 * @param variantPack input/output tensors passed to Setup and Execute
 * @param contextPtr context pointer
 * @param stream stream synchronized after Execute
 * @return atb::Status atb error code
 * @note The workspace is allocated after Setup and freed after the stream is synchronized, so every call
 *       gets a buffer of exactly the size its own Setup requested.
 */
static atb::Status ExecuteRingMLAOnce(atb::Operation *operation, atb::VariantPack &variantPack,
                                      atb::Context *contextPtr, aclrtStream stream)
{
    uint64_t workspaceSize = 0;
    CHECK_STATUS(operation->Setup(variantPack, workspaceSize, contextPtr));
    uint8_t *workspacePtr = nullptr;
    if (workspaceSize > 0) {
        CHECK_STATUS(aclrtMalloc(reinterpret_cast<void **>(&workspacePtr), workspaceSize,
                                 ACL_MEM_MALLOC_HUGE_FIRST));
    }
    CHECK_STATUS(operation->Execute(variantPack, workspacePtr, workspaceSize, contextPtr));
    // Synchronize before freeing so the workspace is not released while the stream may still be using it.
    CHECK_STATUS(aclrtSynchronizeStream(stream));
    if (workspacePtr != nullptr) {
        CHECK_STATUS(aclrtFree(workspacePtr));
    }
    return atb::ErrorType::NO_ERROR;
}

/**
 * @brief Run the RingMLA demo in two passes: a first-ring pass, then a default-calc-type pass that also takes
 *        the first pass's output and softmax LSE as inputs.
 * @param contextPtr context pointer
 * @param stream stream the operations run on
 * @param ringMLAOp passed by value and used only as local storage; the caller's pointer is NOT updated.
 *        Every operation created here is destroyed before a successful return.
 * @return atb::Status atb error code
 * @note On success, frees the device memory of every input tensor. This includes the output/LSE tensors,
 *       because they are appended to the inputs for the second pass.
 */
atb::Status RunRingMLADemo(atb::Context *contextPtr, aclrtStream stream, atb::Operation *ringMLAOp)
{
    atb::infer::RingMLAParam ringMLAParam;
    ringMLAParam.calcType = atb::infer::RingMLAParam::CalcType::CALC_TYPE_FISRT_RING;
    ringMLAParam.headNum = HEAD_NUM;
    ringMLAParam.kvHeadNum = KV_HEAD_NUM;
    // 1 / sqrt(d) over the full query/key head size (nope part + rope part).
    ringMLAParam.qkScale = 1 / std::sqrt(NOPE_HEAD_SIZE + ROPE_HEAD_SIZE);
    ringMLAParam.kernelType = atb::infer::RingMLAParam::KernelType::KERNELTYPE_HIGH_PRECISION;
    ringMLAParam.maskType = atb::infer::RingMLAParam::MaskType::MASK_TYPE_TRIU;

    atb::VariantPack ringMLAVariantPack;
    CHECK_STATUS(PrepareFirstRingInTensor(contextPtr, stream, seqLenHost, ringMLAVariantPack.inTensors));
    atb::Tensor tensorOutput;
    atb::Tensor tensorSoftmaxLse;
    CHECK_STATUS(CreateTensor(ACL_BF16, aclFormat::ACL_FORMAT_ND, {Q_NTOKENS, HEAD_NUM, NOPE_HEAD_SIZE},
                              tensorOutput));
    CHECK_STATUS(CreateTensor(ACL_FLOAT, aclFormat::ACL_FORMAT_ND, {HEAD_NUM, Q_NTOKENS}, tensorSoftmaxLse));
    ringMLAVariantPack.outTensors = {tensorOutput, tensorSoftmaxLse};

    // Pass 1: first ring. The operation is destroyed as soon as it has run.
    CHECK_STATUS(atb::CreateOperation(ringMLAParam, &ringMLAOp));
    CHECK_STATUS(ExecuteRingMLAOnce(ringMLAOp, ringMLAVariantPack, contextPtr, stream));
    CHECK_STATUS(atb::DestroyOperation(ringMLAOp));
    ringMLAOp = nullptr;

    // Pass 2: default calc type. The previous output and LSE are appended as extra inputs, while outTensors
    // still refer to the same two tensors (presumably updated in place — confirm against the op spec).
    ringMLAParam.calcType = atb::infer::RingMLAParam::CalcType::CALC_TYPE_DEFAULT;
    CHECK_STATUS(atb::CreateOperation(ringMLAParam, &ringMLAOp));
    ringMLAVariantPack.inTensors.push_back(tensorOutput);
    ringMLAVariantPack.inTensors.push_back(tensorSoftmaxLse);
    CHECK_STATUS(ExecuteRingMLAOnce(ringMLAOp, ringMLAVariantPack, contextPtr, stream));
    CHECK_STATUS(atb::DestroyOperation(ringMLAOp));
    ringMLAOp = nullptr;

    // tensorOutput / tensorSoftmaxLse are now in inTensors, so this loop releases them too (exactly once).
    for (atb::Tensor &inTensor : ringMLAVariantPack.inTensors) {
        CHECK_STATUS(aclrtFree(inTensor.deviceData));
    }
    return atb::ErrorType::NO_ERROR;
}
int main(int argc, char **argv)
{
int32_t deviceId = 0;
if (argc == 2) {
deviceId = std::stoi(argv[1]);
}
CHECK_STATUS(aclInit(nullptr));
CHECK_STATUS(aclrtSetDevice(deviceId));
atb::Context *context = nullptr;
CHECK_STATUS(atb::CreateContext(&context));
void *stream = nullptr;
CHECK_STATUS(aclrtCreateStream(&stream));
context->SetExecuteStream(stream);
atb::Operation *ringMLAOp = nullptr;
RunRingMLADemo(context, stream, ringMLAOp);
CHECK_STATUS(atb::DestroyOperation(ringMLAOp));
CHECK_STATUS(aclrtDestroyStream(stream));
CHECK_STATUS(DestroyContext(context));
CHECK_STATUS((aclFinalize()));
std::cout << "RingMLA demo success!" << std::endl;
return 0;
}