aclnnMlaPreprocessV2
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
接口功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程如下:
- 首先对输入xx RmsNormQuant后乘以WDQKVW^{DQKV}进行下采样后分为通路1和通路2。
- 通路1做RmsNormQuant后乘以WUQW^{UQ}后再分为通路3和通路4。
- 通路3后乘以WukW^{uk}后输出qNq^N。
- 通路4后经过旋转位置编码后输出qRq^R。
- 通路2拆分为通路5和通路6。
- 通路5经过RmsNorm后传入Cache中得到kNk^N。
- 通路6经过旋转位置编码后传入另一个Cache中得到kRk^R。
-
计算流程图

-
计算公式:
RmsNormQuant公式
RMS(x)=1N∑i=1Nxi2+ϵ\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}
RmsNorm(x)=γ⋅xiRMS(x)\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)}
RmsNormQuant(x)=(RmsNorm(x)+bias)∗deqScaleRmsNormQuant(x) = ({RmsNorm}(x) + bias) * deqScale
Query计算公式,包括W^{DQKV}矩阵乘、W^{UK}矩阵乘、RmsNormQuant和ROPE旋转位置编码处理
qN=RmsNormQuant(x)⋅WDQKV⋅WUKq^N = RmsNormQuant(x) \cdot W^{DQKV} \cdot W^{UK}
qR=ROPE(xQ)q^R = ROPE(x^Q)
Key计算公式,包括RmsNorm和rope,将计算结果存入cache
kN=Cache(RmsNorm(RmsNormQuant(x)))k^N = Cache({RmsNorm}(RmsNormQuant(x)))
kR=Cache(ROPE(RmsNormQuant(x)))k^R = Cache(ROPE(RmsNormQuant(x)))
函数原型
每个算子分为两段式接口,必须先调用“aclnnMlaPreprocessV2GetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnMlaPreprocessV2”接口执行计算。
aclnnStatus aclnnMlaPreprocessV2GetWorkspaceSize(
const aclTensor *input,
const aclTensor *gamma0,
const aclTensor *beta0,
const aclTensor *quantScale0,
const aclTensor *quantOffset0,
const aclTensor *wdqkv,
const aclTensor *deScale0,
const aclTensor *bias0,
const aclTensor *gamma1,
const aclTensor *beta1,
const aclTensor *quantScale1,
const aclTensor *quantOffset1,
const aclTensor *wuq,
const aclTensor *deScale1,
const aclTensor *bias1,
const aclTensor *gamma2,
const aclTensor *cos,
const aclTensor *sin,
const aclTensor *wuk,
const aclTensor *kvCache,
const aclTensor *kvCacheRope,
const aclTensor *slotMapping,
const aclTensor *ctkvScale,
const aclTensor *qNopeScale,
int64_t wdqDim,
int64_t qRopeDim,
int64_t kRopeDim,
double epsilon,
int64_t qRotaryCoeff,
int64_t kRotaryCoeff,
bool transposeWdq,
bool transposeWuq,
bool transposeWuk,
int64_t cacheMode,
int64_t quantMode,
bool doRmsNorm,
int64_t wdkvSplitCount,
bool qDownOutFlag,
const aclTensor *qOut,
const aclTensor *kvCacheOut,
const aclTensor *qRopeOut,
const aclTensor *krCacheOut,
const aclTensor *qDownOut,
uint64_t *workspaceSize,
aclOpExecutor **executor)
aclnnStatus aclnnMlaPreprocessV2(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream)
aclnnMlaPreprocessV2GetWorkspaceSize
-
参数说明
参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续tensor input 输入 用于计算Query和Key的x。 - FLOAT16、BFLOAT16 ND [tokenNum,hiddenSize] - gamma0 输入 首次RmsNorm计算中的γ参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 FLOAT16、BFLOAT16 ND [hiddenSize] - beta0 输入 首次RmsNorm计算中的β参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 FLOAT16、BFLOAT16 ND [hiddenSize] - quantScale0 输入 首次RmsNorm公式中量化缩放的参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 FLOAT16、BFLOAT16 ND [1] - quantOffset0 输入 首次RmsNorm公式中的量化偏移参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 INT8 ND [1] - wdqkv 输入 与输入首次做矩阵乘的降维矩阵。 - INT8、FLOAT16、BFLOAT16 NZ [qLoraDim + keyTotalDim,hiddenSize] - deScale0 输入 输入首次做矩阵乘的降维矩阵中的系数。 input输入dtype为FLOAT16支持INT64,输入BFLOAT16时支持FLOAT。 INT32、FLOAT ND [qLoraDim + keyTotalDim] - bias0 输入 输入首次做矩阵乘的降维矩阵中的系数。 支持传入空tensor,quantMode为1、3时不传入。 INT32 ND [qLoraDim + keyTotalDim] - gamma1 输入 第二次RmsNorm计算中的γ参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 FLOAT16、BFLOAT16 ND [qLoraDim] - beta1 输入 第二次RmsNorm计算中的β参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 FLOAT16、BFLOAT16 ND [qLoraDim] - quantScale1 输入 第二次RmsNorm公式中量化缩放的参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 FLOAT16、BFLOAT16 ND [1] - quantOffset1 输入 第二次RmsNorm公式中的量化偏移参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 INT8 ND [1] - wuq 输入 权重矩阵。 - INT8、FLOAT16、BFLOAT16 NZ [headNum * (qNoRopeDim + qRopeDim),qLoraDim] - deScale1 输入 参与wuq矩阵乘的系数。 input输入dtype为FLOAT16支持INT64,输入BFLOAT16时支持FLOAT。 INT64、FLOAT ND [headNum * (qNoRopeDim + qRopeDim)] - bias1 输入 参与wuq矩阵乘的系数。 quantMode为1、3时不传入。 INT32 ND [headNum * (qNoRopeDim + qRopeDim)] - gamma2 输入 参与RmsNormAndreshapeAndCache计算的γ参数。 数据类型需要与input满足数据类型推导规则(参见互推导关系和约束说明)。 FLOAT16、BFLOAT16 ND [512] - cos 输入 用于计算旋转位置编码的正弦参数矩阵。 - FLOAT16、BFLOAT16 ND [tokenNum,64] - sin 输入 用于计算旋转位置编码的余弦参数矩阵。 - FLOAT16、BFLOAT16 ND [tokenNum,64] - wuk 输入 表示计算Key的上采样权重。 - FLOAT16、BFLOAT16 ND [headNum,qNoRopeDim,512] - kvCache 输入 与输出的kvCacheOut为同一tensor。 输入格式随cacheMode变化: INT8、FLOAT16、BFLOAT16 ND、NZ - - kvCacheRope 输入 与输出的krCacheOut为同一tensor。 可选参数,支出传入空指针,输入格式随cacheMode变化: FLOAT16、BFLOAT16 ND、NZ - - slotMapping 输入 表示用于存储kv_cache和kr_cache的索引。 - INT32 ND [tokenNum] - ctkvScale 输入 输出量化处理中参与计算的系数 仅在cacheMode为2时传入。 FLOAT16、BFLOAT16 ND [1] - qNopeScale 输入 输出量化处理中参与计算的系数。 仅在cacheMode为2时传入 FLOAT16、BFLOAT16 ND [1] - wdqDim 输入 表示经过matmul后拆分的dim大小。 预留参数,目前只支持1536。 int64_t - - - qRopeDim 输入 表示q传入rope的dim大小。 预留参数,目前只支持64。 int64_t - - - kRopeDim 输入 表示k传入rope的dim大小。 预留参数,目前只支持64。 int64_t - - - epsilon 输入 表示加在分母上防止除0。 - double - - - qRotaryCoeff 输入 表示q旋转系数。 预留参数,目前只支持2。 int64_t - - - kRotaryCoeff 输入 表示k旋转系数。 预留参数,目前只支持2。 int64_t - - - transposeWdq 输入 表示wdq是否转置。 预留参数,目前只支持true。 bool - - - transposeWuq 输入 表示wuq是否转置。 预留参数,目前只支持true。 bool - - - transposeWuk 输入 表示wuk是否转置。 预留参数,目前只支持true。 bool - - - cacheMode 输入 表示指定cache的类型。 - 0:kcache和q均经过拼接后输出。
- 1:输出的kvCacheOut拆分为kvCacheOut和krCacheOut,qOut拆分为qOut和qRopeOut。
- 2:krope和ctkv转为NZ格式输出,ctkv和qnope经过per_head静态对称量化为int8类型。
- 3:krope和ctkv转为NZ格式输出。
int64_t - - - quantMode 输入 表示指定RmsNorm量化的类型。 - 0:per_tensor静态非对称量化,默认量化类型。
- 1:per_token动态对称量化,未实现。
- 2:per_token动态非对称量化,未实现。
- 3:不量化,浮点输出,未实现。
int64_t - - - doRmsNorm 输入 控制对输入tensor做RmsNormQuant或者做Quant。 - false:输入tensor只做Quant不做RmsNorm
- true:输入tensor做RmsNormQuant操作。
bool - - - wdkvSplitCount 输入 表示指定wdkv拆分的个数。 支持[1-3],分别表示不拆分、拆分为2个、拆分为3个降维矩阵。预留参数,目前只支持1。 int64_t - - - qDownOutFlag 输入 表示是否输出qDownOut。 false表示不输出,true表示输出。 bool - - - qOut 输出 表示Query的输出tensor,对应计算流图中右侧经过NOPE和矩阵乘后的输出。 shape和dtype随cacheMode变化: INT8、FLOAT16、BFLOAT16 ND - - kvCacheOut 输出 表示Key经过ReshapeAndCache后的输出。 shape和dtype随cacheMode变化: INT8、FLOAT16、BFLOAT16 ND、NZ - - qRopeOut 输出 表示Query经过旋转编码后的输出。 shape和dtype随cacheMode变化: FLOAT16、BFLOAT16 ND - - krCacheOut 输出 表示Key经过ROPE和ReshapeAndCache后的输出。 shape和dtype随cacheMode变化: FLOAT16、BFLOAT16 ND、NZ - - qDownOut 输出 表示Query经过降维后的输出。 - FLOAT16、BFLOAT16 ND [tokenNum, 1536] - workspaceSize 输出参数 返回需要在Device侧申请的workspace大小。 - - - - - executor 输出参数 返回op执行器,包含了算子计算流程。 - - - - - -
返回值
aclnnStatus:返回状态码,具体参见aclnn返回码。
第一段接口完成入参校验,出现以下场景时报错:
返回值 错误码 描述 ACLNN_ERR_PARAM_NULLPTR 161001 必须传入的参数中存在空指针。 ACLNN_ERR_PARAM_INVALID 161002 输入参数的shape、dtype和数据类型不在支持的范围之内。 ACLNN_ERR_RUNTIME_ERROR 361001 API内存调用npu runtime的接口异常。 ACLNN_ERR_INNER_TILING_ERROR 561002 tiling发生异常,入参的dtype类型或者shape错误。
aclnnMlaPreprocessV2
-
参数说明
参数名 输入/输出 描述 workspace 输入 在Device侧申请的workspace内存地址。 workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnBatchMatMulGetWorkspaceSize获取。 executor 输入 op执行器,包含了算子计算流程。 stream 输入 指定执行任务的Stream。 -
返回值
aclnnStatus:返回状态码,具体参见aclnn返回码。
约束说明
- 确定性计算:
- aclnnMlaPreprocessV2默认确定性实现。
- shape格式字段含义及约束
- tokenNum:tokenNum 表示输入样本批量大小,取值范围:0~256
- hiddenSize:hiddenSize 表示隐藏层的大小,取值固定为:2048~10240,为256的倍数
- headNum:表示多头数,取值范围:1~128
- blockNum:PagedAttention场景下的块数,取值范围:192
- blockSize:PagedAttention场景下的块大小,取值范围:128
- qloraDim:表示Q矩阵的LoRA输入维度,取值范围:32~4096,为32的倍数
- keyTotalDim:表示Key部分的总维度,取值固定为:576(512主维度+64 rope维度)
- qRopeDim:表示Q矩阵中旋转编码部分的维度,取值固定为:64
- qNoRopeDim:表示Q矩阵中无旋转编码部分的维度,取值范围:16~256,为16的倍数
- rope模式约束
- mla_preprocess 算子中的 Rotary Embedding(RoPE)操作采用 half 模式,暂不支持 interleave 模式
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2025 Huawei Technologies Co., Ltd.|Hisilicon Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file test_aclnn_mla_preprocess_v2.cpp
* \brief
*/
#include <iostream>
#include <vector>
#include <sys/stat.h>
#include <fstream>
#include <fcntl.h>
#include <unistd.h>
#include <cstdio>
#include <cassert>
#include <iomanip>
#include <unistd.h>
#include "acl/acl.h"
#include "aclnn/acl_meta.h"
#include "aclnnop/aclnn_mla_preprocess.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
template <typename T>
bool ReadFile(const std::string &filePath, std::vector<int64_t> shape, std::vector<T>& hostData)
{
size_t fileSize = 1;
for (int64_t i : shape){
fileSize *= i;
}
std::ifstream file(filePath, std::ios::binary);
if (!file.is_open()) {
std::cerr << "无法打开文件" << std::endl;
return 1;
}
// 获取文件大小
file.seekg(0, std::ios::end);
file.seekg(0, std::ios::beg);
hostData.reserve(fileSize);
if (file.read(reinterpret_cast<char*>(hostData.data()), fileSize * sizeof(T))) {
} else {
std::cerr << "读取文件失败" << std::endl;
return 1;
}
file.close();
return true;
}
template <typename T>
bool WriteFile(const std::string &filePath, int64_t size, std::vector<T>& hostData)
{
int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE);
if (fd < 0) {
LOG_PRINT("Open file failed. path = %s", filePath.c_str());
return false;
}
size_t writeSize = write(fd, reinterpret_cast<char*>(hostData.data()), size * sizeof(T));
(void)close(fd);
if (writeSize != size * sizeof(T)) {
LOG_PRINT("Write file Failed.");
return false;
}
return true;
}
int64_t GetShapeSize(const std::vector<int64_t>& shape)
{
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
void PrintOutResult(std::vector<int64_t>& shape, void** deviceAddr, int num)
{
auto size = GetShapeSize(shape);
std::vector<float> resultData(size, 0);
auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), *deviceAddr,
size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
for (int64_t i = 0; i < 10; i++) {
LOG_PRINT("result[%ld] is: %f\n", i, resultData[i]);
}
}
int Init(int32_t deviceId, aclrtStream *stream) {
// 固定写法,资源初始化
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret);
return ret);
ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret);
return ret);
ret = aclrtCreateStream(stream);
CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret);
return ret);
return 0;
}
template <typename T>
int CreateAclTensor(const std::vector<T> &hostData,
const std::vector<int64_t> &shape, void **deviceAddr,
aclDataType dataType, aclTensor **tensor) {
auto size = GetShapeSize(shape) * sizeof(T);
// 调用aclrtMalloc申请device侧内存
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret);
return ret);
// 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size,
ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret);
return ret);
// 计算连续tensor的strides
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
// 调用aclCreateTensor接口创建aclTensor
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType,
strides.data(), 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), *deviceAddr);
return 0;
}
template <typename T>
int CreateAclTensorND(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
aclDataType dataType, aclTensor** tensor) {
auto size = GetShapeSize(shape) * sizeof(T);
// 调用aclrtMalloc申请device侧内存
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc ND tensor device failed. ERROR: %d\n", ret); return ret);
// 调用aclrtMalloc申请host侧内存
ret = aclrtMalloc(hostAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc ND tensor host failed. ERROR: %d\n", ret); return ret);
// 调用aclCreateTensor接口创建aclTensor
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), *deviceAddr);
// 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, GetShapeSize(shape)*aclDataTypeSize(dataType), ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
return 0;
}
template <typename T>
int CreateAclTensorNZ(const std::vector<T>& shape, void** deviceAddr, void** hostAddr,
aclDataType dataType, aclTensor** tensor) {
auto size = GetShapeSize(shape) * sizeof(T);
// 调用aclrtMalloc申请device侧内存
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc NZ tensor device failed. ERROR: %d\n", ret); return ret);
// 调用aclrtMalloc申请host侧内存
ret = aclrtMalloc(hostAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc NZ tensor device failed. ERROR: %d\n", ret); return ret);
// 调用aclCreateTensor接口创建aclTensor
*tensor = aclCreateTensor(shape.data(), shape.size (), dataType, nullptr, 0, aclFormat::ACL_FORMAT_FRACTAL_NZ,
shape.data(), shape.size (), *deviceAddr);
// 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
ret = aclrtMemcpy(*deviceAddr, size, *hostAddr, GetShapeSize(shape)*aclDataTypeSize(dataType), ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
return 0;
}
int TransToNZShape(std::vector<int64_t> &shapeND, size_t typeSize) {
int64_t h = shapeND[0];
int64_t w = shapeND[1];
int64_t h0 = 16;
int64_t w0 = 32U / typeSize;
int64_t h1 = h / h0;
int64_t w1 = w / w0;
shapeND[0] = w1;
shapeND[1] = h1;
shapeND.emplace_back(h0);
shapeND.emplace_back(w0);
return 0;
}
int main() {
// 1. (固定写法)device/stream初始化,acl API手册
// 根据自己的实际device填写deviceId
int32_t deviceId = 5;
aclrtStream stream;
auto ret = Init(deviceId, &stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret);
return ret);
//属性
int64_t tokenNum = 8;
int64_t hiddenNum = 7168;
int64_t headNum = 32;
int64_t blockNum = 192;
int64_t blockSize = 128;
int64_t wdqDim = 128;
int64_t qRopeDim = 0;
int64_t kRopeDim = 0;
double epsilon = 1e-05;
int64_t qRotaryCoeff = 2;
int64_t kRotaryCoeff = 2;
bool transposeWdq = true;
bool transposeWuq = true;
bool transposeWuk = true;
int64_t cacheMode = 1;
int64_t quantMode = 0;
bool doRmsNorm = true;
int64_t wdkvSplitCount = 1;
// 2. 构造输入与输出,需要根据API的接口自定义构造
std::vector<int64_t> inputShape = {tokenNum, hiddenNum};
std::vector<int64_t> gamma0Shape = {hiddenNum};
std::vector<int64_t> beta0Shape = {hiddenNum};
std::vector<int64_t> quantScale0Shape = {1};
std::vector<int64_t> quantOffset0Shape = {1};
std::vector<int64_t> wdqkvShape = {2112, hiddenNum};
std::vector<int64_t> deScale0Shape = {2112};
std::vector<int64_t> bias0Shape = {2112};
std::vector<int64_t> gamma1Shape = {1536};
std::vector<int64_t> beta1Shape = {1536};
std::vector<int64_t> quantScale1Shape = {1};
std::vector<int64_t> quantOffset1Shape = {1};
std::vector<int64_t> wuqShape = {headNum * 192, 1536};
std::vector<int64_t> deScale1Shape = {headNum * 192};
std::vector<int64_t> bias1Shape = {headNum * 192};
std::vector<int64_t> gamma2Shape = {512};
std::vector<int64_t> cosShape = {tokenNum, 64};
std::vector<int64_t> sinShape = {tokenNum, 64};
std::vector<int64_t> wukShape = {headNum, 128, 512};
std::vector<int64_t> kvCacheShape = {blockNum, blockSize, 1, 576};
std::vector<int64_t> kvCacheRopeShape = {blockNum, blockSize, 1, 64};
std::vector<int64_t> slotMappingShape = {tokenNum};
std::vector<int64_t> ctkvScaleShape = {1};
std::vector<int64_t> qNopeScaleShape = {headNum};
std::vector<int64_t> qOutShape = {tokenNum, headNum, 576};
std::vector<int64_t> kvCacheOutShape = {blockNum, blockSize, 1, 576};
std::vector<int64_t> qRopeOutShape = {tokenNum, headNum, 64};
std::vector<int64_t> krCacheOutShape = {blockNum, blockSize, 1, 64};
void* inputDeviceAddr = nullptr;
void* gamma0DeviceAddr = nullptr;
void* beta0DeviceAddr = nullptr;
void* quantScale0DeviceAddr = nullptr;
void* quantOffset0DeviceAddr = nullptr;
void* wdqkvDeviceAddr = nullptr;
void* deScale0DeviceAddr = nullptr;
void* bias0DeviceAddr = nullptr;
void* gamma1DeviceAddr = nullptr;
void* beta1DeviceAddr = nullptr;
void* quantScale1DeviceAddr = nullptr;
void* quantOffset1DeviceAddr = nullptr;
void* wuqDeviceAddr = nullptr;
void* deScale1DeviceAddr = nullptr;
void* bias1DeviceAddr = nullptr;
void* gamma2DeviceAddr = nullptr;
void* cosDeviceAddr = nullptr;
void* sinDeviceAddr = nullptr;
void* wukDeviceAddr = nullptr;
void* kvCacheDeviceAddr = nullptr;
void* kvCacheRopeDeviceAddr = nullptr;
void* slotMappingDeviceAddr = nullptr;
void* ctkvScaleDeviceAddr = nullptr;
void* qNopeScaleDeviceAddr = nullptr;
void* qOutDeviceAddr = nullptr;
void* kvCacheOutDeviceAddr = nullptr;
void* qRopeOutDeviceAddr = nullptr;
void* krCacheOutDeviceAddr = nullptr;
void* inputHostAddr = nullptr;
void* gamma0HostAddr = nullptr;
void* beta0HostAddr = nullptr;
void* quantScale0HostAddr = nullptr;
void* quantOffset0HostAddr = nullptr;
void* wdqkvHostAddr = nullptr;
void* deScale0HostAddr = nullptr;
void* bias0HostAddr = nullptr;
void* gamma1HostAddr = nullptr;
void* beta1HostAddr = nullptr;
void* quantScale1HostAddr = nullptr;
void* quantOffset1HostAddr = nullptr;
void* wuqHostAddr = nullptr;
void* deScale1HostAddr = nullptr;
void* bias1HostAddr = nullptr;
void* gamma2HostAddr = nullptr;
void* cosHostAddr = nullptr;
void* sinHostAddr = nullptr;
void* wukHostAddr = nullptr;
void* kvCacheHostAddr = nullptr;
void* kvCacheRopeHostAddr = nullptr;
void* slotMappingHostAddr = nullptr;
void* ctkvScaleHostAddr = nullptr;
void* qNopeScaleHostAddr = nullptr;
void* qOutHostAddr = nullptr;
void* kvCacheOutHostAddr = nullptr;
void* qRopeOutHostAddr = nullptr;
void* krCacheOutHostAddr = nullptr;
aclTensor* input = nullptr;
aclTensor* gamma0 = nullptr;
aclTensor* beta0 = nullptr;
aclTensor* quantScale0 = nullptr;
aclTensor* quantOffset0 = nullptr;
aclTensor* wdqkv = nullptr;
aclTensor* deScale0 = nullptr;
aclTensor* bias0 = nullptr;
aclTensor* gamma1 = nullptr;
aclTensor* beta1 = nullptr;
aclTensor* quantScale1 = nullptr;
aclTensor* quantOffset1 = nullptr;
aclTensor* wuq = nullptr;
aclTensor* deScale1 = nullptr;
aclTensor* bias1 = nullptr;
aclTensor* gamma2 = nullptr;
aclTensor* cos = nullptr;
aclTensor* sin = nullptr;
aclTensor* wuk = nullptr;
aclTensor* kvCache = nullptr;
aclTensor* kvCacheRope = nullptr;
aclTensor* slotMapping = nullptr;
aclTensor* ctkvScale = nullptr;
aclTensor* qNopeScale = nullptr;
aclTensor* qOut = nullptr;
aclTensor* kvCacheOut = nullptr;
aclTensor* qRopeOut = nullptr;
aclTensor* krCacheOut = nullptr;
// 转换三个NZ格式变量的shape
ret = TransToNZShape(wdqkvShape, sizeof(int8_t));
CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed. \n"); return ret);
ret = TransToNZShape(wuqShape, sizeof (int8_t));
CHECK_RET(ret == 0, LOG_PRINT("trans NZ shape failed. \n"); return ret);
ret = CreateAclTensorND(inputShape, &inputDeviceAddr, &inputHostAddr, aclDataType::ACL_FLOAT16, &input);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(gamma0Shape, &gamma0DeviceAddr, &gamma0HostAddr, aclDataType::ACL_FLOAT16, &gamma0);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(beta0Shape, &beta0DeviceAddr, &beta0HostAddr, aclDataType::ACL_FLOAT16, &beta0);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(quantScale0Shape, &quantScale0DeviceAddr, &quantScale0HostAddr, aclDataType::ACL_FLOAT16, &quantScale0);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(quantOffset0Shape, &quantOffset0DeviceAddr, &quantOffset0HostAddr, aclDataType::ACL_INT8, &quantOffset0);
CHECK_RET(ret == ACL_SUCCESS, return ret);
//wdqkv转为NZ
ret = CreateAclTensorNZ(wdqkvShape, &wdqkvDeviceAddr, &wdqkvHostAddr, aclDataType::ACL_INT8, &wdqkv);
CHECK_RET(ret == ACL_SUCCESS, return ret);
//fp16输入,则这里转为int64
ret = CreateAclTensorND(deScale0Shape, &deScale0DeviceAddr, &deScale0HostAddr, aclDataType::ACL_INT64, &deScale0);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(bias0Shape, &bias0DeviceAddr, &bias0HostAddr, aclDataType::ACL_INT32, &bias0);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(gamma1Shape, &gamma1DeviceAddr, &gamma1HostAddr, aclDataType::ACL_FLOAT16, &gamma1);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(beta1Shape, &beta1DeviceAddr, &beta1HostAddr, aclDataType::ACL_FLOAT16, &beta1);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(quantScale1Shape, &quantScale1DeviceAddr, &quantScale1HostAddr, aclDataType::ACL_FLOAT16, &quantScale1);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(quantOffset1Shape, &quantOffset1DeviceAddr, &quantOffset1HostAddr, aclDataType::ACL_INT8, &quantOffset1);
CHECK_RET(ret == ACL_SUCCESS, return ret);
//wuq转为NZ
ret = CreateAclTensorNZ(wuqShape, &wuqDeviceAddr, &wuqHostAddr, aclDataType::ACL_INT8, &wuq);
CHECK_RET(ret == ACL_SUCCESS, return ret);
//fp16输入,则这里转为int64
ret = CreateAclTensorND(deScale1Shape, &deScale1DeviceAddr, &deScale1HostAddr, aclDataType::ACL_INT64, &deScale1);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(bias1Shape, &bias1DeviceAddr, &bias1HostAddr, aclDataType::ACL_INT32, &bias1);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(gamma2Shape, &gamma2DeviceAddr, &gamma2HostAddr, aclDataType::ACL_FLOAT16, &gamma2);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(cosShape, &cosDeviceAddr, &cosHostAddr, aclDataType::ACL_FLOAT16, &cos);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(sinShape, &sinDeviceAddr, &sinHostAddr, aclDataType::ACL_FLOAT16, &sin);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(wukShape, &wukDeviceAddr, &wukHostAddr, aclDataType::ACL_FLOAT16, &wuk);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(kvCacheShape, &kvCacheDeviceAddr, &kvCacheHostAddr, aclDataType::ACL_FLOAT16, &kvCache);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(kvCacheRopeShape, &kvCacheRopeDeviceAddr, &kvCacheRopeHostAddr, aclDataType::ACL_FLOAT16, &kvCacheRope);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(slotMappingShape, &slotMappingDeviceAddr, &slotMappingHostAddr, aclDataType::ACL_INT32, &slotMapping);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(ctkvScaleShape, &ctkvScaleDeviceAddr, &ctkvScaleHostAddr, aclDataType::ACL_FLOAT16, &ctkvScale);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(qNopeScaleShape, &qNopeScaleDeviceAddr, &qNopeScaleHostAddr, aclDataType::ACL_FLOAT16, &qNopeScale);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(qOutShape, &qOutDeviceAddr, &qOutHostAddr, aclDataType::ACL_FLOAT16, &qOut);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(kvCacheOutShape, &kvCacheOutDeviceAddr, &kvCacheOutHostAddr, aclDataType::ACL_FLOAT16, &kvCacheOut);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(qRopeOutShape, &qRopeOutDeviceAddr, &qRopeOutHostAddr, aclDataType::ACL_FLOAT16, &qRopeOut);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensorND(krCacheOutShape, &krCacheOutDeviceAddr, &krCacheOutHostAddr, aclDataType::ACL_FLOAT16, &krCacheOut);
CHECK_RET(ret == ACL_SUCCESS, return ret);
// 3. 调用CANN算子库API,需要修改为具体的API名称
uint64_t workspaceSize = 0;
aclOpExecutor *executor;
// 调用acaclnnMlaPreprocess第一段接口
ret = aclnnMlaPreprocessGetWorkspaceSize(
input, gamma0, beta0, quantScale0, quantOffset0,
wdqkv, deScale0, bias0, gamma1, beta1, quantScale1, quantOffset1, wuq, deScale1, bias1, gamma2, cos, sin, wuk, kvCache, kvCacheRope, slotMapping, ctkvScale, qNopeScale,
wdqDim, qRopeDim, kRopeDim, epsilon, qRotaryCoeff, kRotaryCoeff, transposeWdq, transposeWuq, transposeWuk, cacheMode, quantMode, doRmsNorm, wdkvSplitCount, qOut, kvCacheOut, qRopeOut, krCacheOut, &workspaceSize, &executor);
CHECK_RET(
ret == ACL_SUCCESS,
LOG_PRINT("acaclnnMlaPreprocessGetWorkspaceSize failed. ERROR: %d\n", ret);
return ret);
// 根据第一段接口计算出的workspaceSize申请device内存
void *workspaceAddr = nullptr;
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret);
return ret);
}
// 调用acaclnnMlaPreprocess第二段接口
ret = aclnnMlaPreprocess(workspaceAddr, workspaceSize, executor, stream);
CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("acaclnnMlaPreprocess failed. ERROR: %d\n", ret);
return ret);
// 4. (固定写法)同步等待任务执行结束
ret = aclrtSynchronizeStream(stream);
CHECK_RET(ret == ACL_SUCCESS,
LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret);
return ret);
// 5.获取输出的值,将device侧内存上的结果拷贝至host侧,需要根据具体API的接口定义修改
auto qOutSize = GetShapeSize(qOutShape);
std::vector<float> qOutData(qOutSize, 0);
ret = aclrtMemcpy(qOutData.data(), qOutData.size() * sizeof(qOutData[0]), qOutDeviceAddr, qOutSize * sizeof(float),
ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
// 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
// 释放aclTensor资源
aclDestroyTensor(input);
aclDestroyTensor(gamma0);
aclDestroyTensor(beta0);
aclDestroyTensor(quantScale0);
aclDestroyTensor(quantOffset0);
aclDestroyTensor(wdqkv);
aclDestroyTensor(deScale0);
aclDestroyTensor(bias0);
aclDestroyTensor(gamma1);
aclDestroyTensor(beta1);
aclDestroyTensor(quantScale1);
aclDestroyTensor(quantOffset1);
aclDestroyTensor(wuq);
aclDestroyTensor(deScale1);
aclDestroyTensor(bias1);
aclDestroyTensor(gamma2);
aclDestroyTensor(cos);
aclDestroyTensor(sin);
aclDestroyTensor(wuk);
aclDestroyTensor(kvCache);
aclDestroyTensor(kvCacheRope);
aclDestroyTensor(slotMapping);
aclDestroyTensor(ctkvScale);
aclDestroyTensor(qNopeScale);
// 7. 释放device 资源
aclrtFree(inputDeviceAddr);
aclrtFree(gamma0DeviceAddr);
aclrtFree(beta0DeviceAddr);
aclrtFree(quantScale0DeviceAddr);
aclrtFree(quantOffset0DeviceAddr);
aclrtFree(wdqkvDeviceAddr);
aclrtFree(deScale0DeviceAddr);
aclrtFree(bias0DeviceAddr);
aclrtFree(gamma1DeviceAddr);
aclrtFree(beta1DeviceAddr);
aclrtFree(quantScale1DeviceAddr);
aclrtFree(quantOffset1DeviceAddr);
aclrtFree(wuqDeviceAddr);
aclrtFree(deScale1DeviceAddr);
aclrtFree(bias1DeviceAddr);
aclrtFree(gamma2DeviceAddr);
aclrtFree(cosDeviceAddr);
aclrtFree(sinDeviceAddr);
aclrtFree(wukDeviceAddr);
aclrtFree(kvCacheDeviceAddr);
aclrtFree(kvCacheRopeDeviceAddr);
aclrtFree(slotMappingDeviceAddr);
aclrtFree(ctkvScaleDeviceAddr);
aclrtFree(qNopeScaleDeviceAddr);
// 8. 释放host 资源
aclrtFree(inputHostAddr);
aclrtFree(gamma0HostAddr);
aclrtFree(beta0HostAddr);
aclrtFree(quantScale0HostAddr);
aclrtFree(quantOffset0HostAddr);
aclrtFree(wdqkvHostAddr);
aclrtFree(deScale0HostAddr);
aclrtFree(bias0HostAddr);
aclrtFree(gamma1HostAddr);
aclrtFree(beta1HostAddr);
aclrtFree(quantScale1HostAddr);
aclrtFree(quantOffset1HostAddr);
aclrtFree(wuqHostAddr);
aclrtFree(deScale1HostAddr);
aclrtFree(bias1HostAddr);
aclrtFree(gamma2HostAddr);
aclrtFree(cosHostAddr);
aclrtFree(sinHostAddr);
aclrtFree(wukHostAddr);
aclrtFree(kvCacheHostAddr);
aclrtFree(kvCacheRopeHostAddr);
aclrtFree(slotMappingHostAddr);
aclrtFree(ctkvScaleHostAddr);
aclrtFree(qNopeScaleHostAddr);
if (workspaceSize > 0) {
aclrtFree(workspaceAddr);
}
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
return 0;
}