aclnnLightingIndexerV2Metadata
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | x |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | x |
| Atlas 200I/500 A2 推理产品 | x |
| Atlas 推理系列产品 | x |
| Atlas 训练系列产品 | x |
功能说明
-
接口功能:该接口为AI CPU算子接口,是aclnnLightningIndexerV2算子的前置算子接口。根据aclnnLightningIndexerV2算子接口的输入信息,计算并输出负载均衡结果。输出结果可以作为aclnnLightningIndexerV2算子接口的输入,减少aclnnLightningIndexerV2算子接口的执行耗时。
该算子不建议单独使用,建议与aclnnLightningIndexerV2算子配合使用,形成完整的工作流。
- 接受aclnnLightningIndexerV2算子接口输入数据shape信息,包含batchSize、qSeqlen、kSeqlen、mask。通过对输入分块并模拟计算耗时,均匀分配分块到可用核上,以降低aclnnLightningIndexerV2算子的整体计算耗时,并提高硬件利用率。
- 分配结果输出后,后续作为输入供aclnnLightningIndexerV2算子使用。
- 分配结果包含每个AIC核基本块的起始点和终止点,已经每个AIV核的FD任务信息。详细内容可以参考调用示例。
函数原型
每个算子分为两段式接口,必须先调用"aclnnLightingIndexerV2MetadataGetWorkspaceSize"获取workspace大小,在调用"aclnnLightingIndexerV2Metadata"执行计算
aclnnStatus aclnnLightningIndexerV2MetadataGetWorkspaceSize(
const aclTensor *cuSeqlensQOptional,
const aclTensor *cuSeqlensKOptional,
const aclTensor *sequsedQOptional,
const aclTensor *sequsedKOptional,
const aclTensor *cmpResidualKOptional,
int64_t numHeadsQ,
int64_t numHeadsK,
int64_t headDim,
int64_t topk,
int64_t batchSize,
int64_t maxSeqlenQ,
int64_t maxSeqlenK,
char *layoutQOptional,
char *layoutKOptional,
int64_t maskMode,
int64_t cmpRatio,
const aclTensor *metaData,
uint64_t *workspaceSize,
aclOpExecutor **executor)
aclnnStatus aclnnLightningIndexerV2Metadata(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream)
aclnnLightingIndexerV2MetadataGetWorkspaceSize
-
参数说明
参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor cuSeqlensQOptional 输入 表示不同Batch中Query的有效Sequence Length
TND场景下必传,以该入参的数量作为Batch值
第一个值为额外值并固定为0。支持空Tensor INT32 ND 1维,shape固定为(B+1,) √ cuSeqlensKOptional 输入 表示不同Batch中Key的有效Sequence Length
TND场景下必传,以该入参的数量作为Batch值
第一个值为额外值并固定为0。支持空Tensor INT32 ND 1维,shape固定为(B+1,) √ sequsedQOptional 输入 表示不同Batch中Query实际参与运算的Sequence Length。 支持空Tensor INT32 ND 1维,shape固定为(B,) √ sequsedKOptional 输入 表示不同Batch中Key实际参与运算的Sequence Length。 支持空Tensor INT32 ND 1维,shape固定为(B,) √ cmpResidualKOptional 输入 预留参数,表示不同Batch中Key的Sequence Length的余数,当前不影响算子计算效果。
如果cmpRatio不为1,且mask为3,则必须传入cmpResidualKOptional。支持空Tensor INT32 ND 1维,shape固定为(B,) √ numHeadsQ 输入 表示Query的head个数。 支持非负数 INT64 - - - numHeadsK 输入 表示Key的head个数。 支持非负数 INT64 - - - headDim 输入 表示token数。 支持非负数 INT64 - - - topk 输入 表示从Query中筛选出的关键稀疏token的个数。 支持非负数 INT64 - - - batchSize 输入 表示Batch数量。 支持非负数
建议值为0INT64 - - - maxSeqlenQ 输入 表示Query的最长Sequence Length。 支持非负数
建议值为0INT64 - - - maxSeqlenK 输入 表示Key的最长Sequence Length。 支持非负数
建议值为0INT64 - - - layoutQOptional 输入 表示Query的排列格式。 支持 BSND、TND
建议值为BSNDINT64 - - - layoutKOptional 输入 表示Key的排列格式。 支持 BSND、TND、PA_BBND
建议值为BSNDINT64 - - - maskMode 输入 表示sparse模式。 0: No mask
3: rightDownCausal模式的mask,对应以右顶点为划分的下三角场景
建议值为0INT64 - - - cmpRatio 输入 预留参数,表示Key的压缩率,当前不影响算子计算效果。
建议值1,表示无压缩。取值范围[1,128] INT64 - - - metaData 输出 表示负载均衡结果输出。 - INT32 ND 1维,shape固定为(1024) × workspaceSize 输出 返回需要在Device侧申请的workspace大小。 - - - - - executor 输出 返回op执行器,包含了算子计算流程。 - - - - - -
返回值:
返回aclnnStatus状态码,具体参见aclnn返回码。
aclnnLightingIndexerV2Metadata
-
参数说明:
参数名 输入/输出 描述 workspace 输入 在Device侧申请的workspace内存地址 workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnLightingIndexerV2MetadataGetWorkspaceSize获取 executor 输入 op执行器,包含了算子计算流程 stream 输入 指定执行任务的Stream -
返回值:
返回aclnnStatus状态码,具体参见aclnn返回码。
约束说明
-
aclnnLightingIndexerV2Metadata默认确定性实现
-
Batch取值规则
- 优先获取sequsedQOptional中的Batch信息
- 如果未传入sequsedQOptional,优先获取cuSeqlensQOptional中的Batch信息
- 如果未传入sequsedQOptional,且layoutQOptional为TND,则必获取cuSeqlensQOptional中的Batch信息
- 除上所述,使用batchSize
-
Sequence Length取值规则
- 优先获取sequsedQOptional中的Sequence Length信息
- 如果未传入sequsedQOptional,且layoutQOptional为TND,则必获取cuSeqlensQOptional中的Sequence Length信息
- 除上所述,使用maxSeqlenQ
- Key与Query的获取规则一致
-
layout约束
- 当layoutKOptional为PA_BBND时,layoutQOptional可以任意取值
- 除上所述,layoutQOptional必须与layoutKOptional保持一致
-
BSND场景
- 当传入的layoutQOptional为"BSND"时,视为使用BSND场景
- 在未传入cuSeqlensQOptional和sequsedQOptional的情况下,必传batchSize、maxSeqlenQ、maxSeqlenK参数
-
TND场景
- 当传入的layoutQOptional为"TND"时,视为使用TND场景
- 必传cuSeqlensQOptional、cuSeqlensKOptional参数
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
#include <iostream>
#include <vector>
#include <cmath>
#include <cstring>
#include <limits>
#include <functional>
#include <utility>
#include "acl/acl.h"
#include "aclnnop/aclnn_lightning_indexer_v2_metadata.h"
#define CHECK_LOG_RET(cond, ret_val, fmt, ...) \
do { \
if (!(cond)) { \
printf(fmt "\n", ##__VA_ARGS__); \
return (ret_val); \
} \
} while (0)
// 参考 lightning_indexer_v2_metadata.h
constexpr uint32_t AIC_CORE_NUM = 36;
constexpr uint32_t AIV_CORE_NUM = 72;
constexpr uint32_t LI_V2_METADATA_SIZE = 8;
constexpr uint32_t LD_V2_METADATA_SIZE = 8;
// LI Metadata Index Definitions
constexpr uint32_t LI_V2_CORE_ENABLE_INDEX = 0;
constexpr uint32_t LI_V2_BN2_START_INDEX = 1;
constexpr uint32_t LI_V2_M_START_INDEX = 2;
constexpr uint32_t LI_V2_S2_START_INDEX = 3;
constexpr uint32_t LI_V2_BN2_END_INDEX = 4;
constexpr uint32_t LI_V2_M_END_INDEX = 5;
constexpr uint32_t LI_V2_S2_END_INDEX = 6;
constexpr uint32_t LI_V2_FIRST_LD_V2_DATA_WORKSPACE_IDX_INDEX = 7;
// LD Metadata Index Definitions
constexpr uint32_t LD_V2_CORE_ENABLE_INDEX = 0;
constexpr uint32_t LD_V2_BN2_IDX_INDEX = 1;
constexpr uint32_t LD_V2_M_IDX_INDEX = 2;
constexpr uint32_t LD_V2_WORKSPACE_IDX_INDEX = 3;
constexpr uint32_t LD_V2_WORKSPACE_NUM_INDEX = 4;
constexpr uint32_t LD_V2_M_START_INDEX = 5;
constexpr uint32_t LD_V2_M_NUM_INDEX = 6;
struct LiV2MetaData {
uint32_t faData[AIC_CORE_NUM][LI_V2_METADATA_SIZE];
uint32_t fdData[AIV_CORE_NUM][LD_V2_METADATA_SIZE];
};
struct ScopeGuard
{
explicit ScopeGuard(std::function<void()> onExitScope) : m_exitFunc(std::move(onExitScope)),
m_isDismissed(false) {}
// 禁止拷贝
ScopeGuard(const ScopeGuard&) = delete;
ScopeGuard& operator=(const ScopeGuard&) = delete;
~ScopeGuard()
{
if (!m_isDismissed) {
m_exitFunc();
}
}
void Dismiss()
{
m_isDismissed = true;
}
std::function<void()> m_exitFunc;
bool m_isDismissed;
};
struct Tensor {
void *hostAddr { nullptr };
void *deviceAddr { nullptr };
aclTensor *data { nullptr };
};
struct ArgScenario {
bool hasCuSeq { false };
bool hasSeqused { false };
};
struct ArgContext {
// required input
int64_t numHeadsQ { 0 };
int64_t numHeadsK { 0 };
int64_t headDim { 0 };
int64_t topk { 0 };
// optional input
Tensor cuSeqlensQOptional {};
Tensor cuSeqlensKOptional {};
Tensor sequsedQOptional {};
Tensor sequsedKOptional {};
Tensor cmpResidualKOptional {};
int64_t batchSize { 0 };
int64_t maxSeqlenQ { 0 };
int64_t maxSeqlenK { 0 };
char *layoutQOptional { nullptr };
char *layoutKOptional { nullptr };
int64_t maskMode { 0 };
int64_t cmpRatio { 0 };
// output
Tensor metaData {};
};
int64_t GetShapeSize(const std::vector<int64_t>& shape)
{
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
aclnnStatus Init(int32_t deviceId, aclrtStream* stream)
{
// 固定写法,初始化
auto ret = aclInit(nullptr);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclInit failed. ERROR: %d", ret);
ret = aclrtSetDevice(deviceId);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtSetDevice failed. ERROR: %d", ret);
ret = aclrtCreateStream(stream);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtCreateStream failed. ERROR: %d", ret);
return ACL_SUCCESS;
}
void Finalize(int32_t deviceId, aclrtStream stream)
{
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
}
aclnnStatus CreateTensor(aclDataType dataType, const std::vector<int64_t> &shape, Tensor &tensor)
{
auto size = GetShapeSize(shape) * aclDataTypeSize(dataType);
// 调用aclrtMallocHost申请host侧内存
auto ret = aclrtMallocHost(&(tensor.hostAddr), size);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMallocHost failed. ERROR: %d", ret);
memset(tensor.hostAddr, 0, size);
// 调用aclrtMalloc申请device侧内存
ret = aclrtMalloc(&(tensor.deviceAddr), size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMalloc failed. ERROR: %d", ret);
// 调用aclCreateTensor接口创建aclTensor
tensor.data = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), tensor.deviceAddr);
// 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
ret = aclrtMemcpy(tensor.deviceAddr, size, tensor.hostAddr, size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMemcpy failed. ERROR: %d", ret);
return ACL_SUCCESS;
}
void DestroyTensor(Tensor &tensor)
{
if (tensor.data != nullptr) {
aclDestroyTensor(tensor.data);
tensor.data = nullptr;
}
if (tensor.deviceAddr != nullptr) {
aclrtFree(tensor.deviceAddr);
tensor.deviceAddr = nullptr;
}
if (tensor.hostAddr != nullptr) {
aclrtFreeHost(tensor.hostAddr);
tensor.hostAddr = nullptr;
}
}
void DestroyArgs(ArgContext &context)
{
DestroyTensor(context.metaData);
DestroyTensor(context.cuSeqlensQOptional);
DestroyTensor(context.cuSeqlensKOptional);
DestroyTensor(context.sequsedQOptional);
DestroyTensor(context.sequsedKOptional);
DestroyTensor(context.cmpResidualKOptional);
if (context.layoutQOptional != nullptr) {
free(context.layoutQOptional);
context.layoutQOptional = nullptr;
}
if (context.layoutKOptional != nullptr) {
free(context.layoutKOptional);
context.layoutKOptional = nullptr;
}
}
aclnnStatus CreateArgs(const ArgScenario &scenario, ArgContext &context)
{
ScopeGuard argsGuard([&] { DestroyArgs(context); });
aclnnStatus ret;
int64_t batchSize = 4;
context.numHeadsQ = 1;
context.numHeadsK = 1;
context.headDim = 128;
context.topk = 0;
ret = CreateTensor(aclDataType::ACL_INT32, { 1024 }, context.metaData); // 1024: Fix size
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create meta failed. Error: %d", ret);
context.maskMode = 0; // 0: no mask, 3: causal
context.cmpRatio = 1; // [1, 128], 1: no compress
context.layoutQOptional = (char *)malloc(sizeof(char) * 16);
context.layoutKOptional = (char *)malloc(sizeof(char) * 16);
strcpy(context.layoutQOptional, "BSND"); // BSND,TND
strcpy(context.layoutKOptional, "BSND"); // BSND,TND,PA_BBND
if (!scenario.hasCuSeq && !scenario.hasSeqused) {
context.batchSize = batchSize;
context.maxSeqlenK = 1024;
context.maxSeqlenQ = 1024;
return ACL_SUCCESS;
}
if (scenario.hasCuSeq) {
// (B+1,), first element is always 0
ret = CreateTensor(aclDataType::ACL_INT32, { batchSize + 1 }, context.cuSeqlensQOptional);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create cuSeqlensQOptional failed. Error: %d", ret);
ret = CreateTensor(aclDataType::ACL_INT32, { batchSize + 1 }, context.cuSeqlensKOptional);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create cuSeqlensKOptional failed. Error: %d", ret);
}
if (scenario.hasSeqused) {
// (B,)
ret = CreateTensor(aclDataType::ACL_INT32, { batchSize }, context.sequsedQOptional);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create sequsedQOptional failed. Error: %d", ret);
ret = CreateTensor(aclDataType::ACL_INT32, { batchSize }, context.sequsedKOptional);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create sequsedKOptional failed. Error: %d", ret);
}
argsGuard.Dismiss();
return ACL_SUCCESS;
}
int main() {
// 1. (固定写法)device/stream初始化,参考对外接口列表
// 根据自己的实际device填写deviceId
int32_t deviceId = 0;
aclrtStream stream;
auto ret = Init(deviceId, &stream);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Init acl failed. ERROR: %d", ret);
ScopeGuard sysGuard([&] { Finalize(deviceId, stream); });
// 2. 构造输入与输出,需要根据API的接口定义构造
ArgScenario scenario {};
scenario.hasCuSeq = true;
scenario.hasSeqused = true;
ArgContext context {};
ret = CreateArgs(scenario, context);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create input arguments failed. ERROR: %d", ret);
ScopeGuard argsGuard([&] { DestroyArgs(context); });
// 3. 调用CANN算子库API,需要修改为具体的API
// 调用aclnnLightningIndexerV2Metadata第一段接口
uint64_t workspaceSize = 0;
aclOpExecutor *executor = nullptr;
void *workspaceAddr = nullptr;
ret = aclnnLightningIndexerV2MetadataGetWorkspaceSize(
context.cuSeqlensQOptional.data, context.cuSeqlensKOptional.data, context.sequsedQOptional.data,
context.sequsedKOptional.data, context.cmpResidualKOptional.data,
context.numHeadsQ, context.numHeadsK, context.headDim, context.topk,
context.batchSize, context.maxSeqlenQ, context.maxSeqlenK, context.layoutQOptional,
context.layoutKOptional, context.maskMode, context.cmpRatio,
context.metaData.data, &workspaceSize, &executor);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclnnLightningIndexerV2MetadataGetWorkspaceSize failed. ERROR: %d\n", ret);
if (workspaceSize > static_cast<uint64_t>(0)) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "allocate workspace failed. ERROR: %d\n", ret);
}
ScopeGuard workspaceGuard([&] {
if (workspaceAddr != nullptr) {
aclrtFree(workspaceAddr);
workspaceAddr = nullptr;
}
});
// 调用aclnnLightningIndexerV2Metadata第二段接口
ret = aclnnLightningIndexerV2Metadata(workspaceAddr, workspaceSize, executor, stream);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclnnLightningIndexerV2Metadata failed. ERROR: %d\n", ret);
// 4. (固定写法)同步等待任务执行结束
ret = aclrtSynchronizeStream(stream);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtSynchronizeStream failed. ERROR: %d\n", ret);
// 5. 打印输出
LiV2MetaData result {};
ret = aclrtMemcpy(&result, sizeof(result), context.metaData.deviceAddr, sizeof(result), ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMemcpy failed. ERROR: %d\n", ret);
for (uint32_t i = 0; i < AIC_CORE_NUM; ++i) {
printf("AIC Core%u\n", i);
printf(" Core Enable : %u\n", result.faData[i][LI_V2_CORE_ENABLE_INDEX]);
printf(" Start BN2 : %u\n", result.faData[i][LI_V2_BN2_START_INDEX]);
printf(" Start M : %u\n", result.faData[i][LI_V2_M_START_INDEX]);
printf(" Start S2 : %u\n", result.faData[i][LI_V2_S2_START_INDEX]);
printf(" End BN2 : %u\n", result.faData[i][LI_V2_BN2_END_INDEX]);
printf(" End M : %u\n", result.faData[i][LI_V2_M_END_INDEX]);
printf(" End S2 : %u\n", result.faData[i][LI_V2_S2_END_INDEX]);
printf(" First Worksapce Index : %u\n", result.faData[i][LI_V2_FIRST_LD_V2_DATA_WORKSPACE_IDX_INDEX]);
}
for (uint32_t i = 0; i < AIV_CORE_NUM; ++i) {
printf("AIV Core%u\n", i);
printf(" Core Enable : %u\n", result.fdData[i][LD_V2_CORE_ENABLE_INDEX]);
printf(" FD Task BN2 Idx : %u\n", result.fdData[i][LD_V2_BN2_IDX_INDEX]);
printf(" FD Task M Idx : %u\n", result.fdData[i][LD_V2_M_IDX_INDEX]);
printf(" FD Task S2 Idx : %u\n", result.fdData[i][LD_V2_WORKSPACE_IDX_INDEX]);
printf(" FD Task Workspace Num : %u\n", result.fdData[i][LD_V2_WORKSPACE_NUM_INDEX]);
printf(" FD Subtask M Start : %u\n", result.fdData[i][LD_V2_M_START_INDEX]);
printf(" FD Subtask M Num : %u\n", result.fdData[i][LD_V2_M_NUM_INDEX]);
}
return 0;
}