aclnnLightningIndexerV2Metadata

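Product Support

Product                                                              Supported
Ascend 950PR/Ascend 950DT                                            √
Atlas A3 Training Series Products/Atlas A3 Inference Series Products ×
Atlas A2 Training Series Products/Atlas A2 Inference Series Products ×
Atlas 200I/500 A2 Inference Products                                 ×
Atlas Inference Series Products                                      ×
Atlas Training Series Products                                       ×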

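Function Description

  • API function: This is an AI CPU operator API and is the preceding API of aclnnLightningIndexerV2. Based on the input information of aclnnLightningIndexerV2, it computes and outputs a load-balancing result. This result can be passed as an input to aclnnLightningIndexerV2 to reduce its execution time.

    Using this operator on its own is not recommended. Use it together with aclnnLightningIndexerV2 to form a complete workflow:

    1. It receives the shape information of the aclnnLightningIndexerV2 input data, including batchSize, qSeqlen, kSeqlen, and mask. It splits the input into blocks, estimates the computation cost of each block by simulation, and distributes the blocks evenly across the available cores. This lowers the overall computation time of aclnnLightningIndexerV2 and improves hardware utilization.
    2. The output allocation result is then passed as an input to aclnnLightningIndexerV2.
    3. The allocation result contains the start and end points of the basic blocks on each AIC core, as well as the FD task information for each AIV core. For details, see the invocation example.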
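The balancing in step 1 can be illustrated with a simplified host-side sketch. It assumes the per-block costs have already been estimated and splits the blocks, in order, into contiguous [start, end) ranges, one per core, so that each range carries roughly total/coreNum of the cost; cores that receive no range stay idle. `SplitBlocksEvenly` and this greedy strategy are hypothetical illustrations, not the operator's actual block definition or cost model.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper (not a CANN API): splits blocks with known (simulated)
// costs into contiguous [start, end) ranges, one range per core, so that each
// range carries roughly total / coreNum of the work. Illustrative only.
std::vector<std::pair<uint32_t, uint32_t>> SplitBlocksEvenly(const std::vector<uint64_t> &costs,
                                                             uint32_t coreNum)
{
    assert(coreNum > 0);
    std::vector<std::pair<uint32_t, uint32_t>> ranges;
    uint64_t total = 0;
    for (uint64_t c : costs) {
        total += c;
    }
    const uint32_t blockNum = static_cast<uint32_t>(costs.size());
    uint32_t start = 0;
    uint64_t assigned = 0;  // cumulative cost assigned to cores so far
    for (uint32_t core = 0; core < coreNum && start < blockNum; ++core) {
        // Cumulative cost this core should reach: ceil(total * (core + 1) / coreNum).
        const uint64_t target = (total * (core + 1) + coreNum - 1) / coreNum;
        uint32_t end = start;
        // Take at least one block, then keep adding blocks while staying within the target.
        do {
            assigned += costs[end];
            ++end;
        } while (end < blockNum && assigned + costs[end] <= target);
        ranges.emplace_back(start, end);
        start = end;
    }
    return ranges;
}
```

For example, SplitBlocksEvenly({1, 1, 1, 1, 1, 1, 6}, 2) yields the ranges [0, 6) and [6, 7), so each core carries a cost of 6. Contiguous ranges are used here because they map naturally onto start/end points like those reported for the AIC cores.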

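Function Prototypes

Each operator uses a two-phase interface: first call "aclnnLightningIndexerV2MetadataGetWorkspaceSize" to obtain the workspace size, and then call "aclnnLightningIndexerV2Metadata" to perform the computation.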

aclnnStatus aclnnLightningIndexerV2MetadataGetWorkspaceSize(
    const aclTensor   *cuSeqlensQOptional,
    const aclTensor   *cuSeqlensKOptional,
    const aclTensor   *sequsedQOptional,
    const aclTensor   *sequsedKOptional,
    const aclTensor   *cmpResidualKOptional,
    int64_t            numHeadsQ,
    int64_t            numHeadsK,
    int64_t            headDim,
    int64_t            topk,
    int64_t            batchSize,
    int64_t            maxSeqlenQ,
    int64_t            maxSeqlenK,
    char              *layoutQOptional,
    char              *layoutKOptional,
    int64_t            maskMode,
    int64_t            cmpRatio,
    const aclTensor   *metaData,
    uint64_t          *workspaceSize,
    aclOpExecutor    **executor);

aclnnStatus aclnnLightningIndexerV2Metadata(
    void              *workspace,
    uint64_t           workspaceSize,
    aclOpExecutor     *executor,
    aclrtStream        stream);

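aclnnLightningIndexerV2MetadataGetWorkspaceSize

  • Parameters:

    cuSeqlensQOptional (input)
      Description: Effective Sequence Lengths of Query in each Batch, stored cumulatively.
      Usage notes: Required in the TND scenario, where the Batch value is derived from the number of elements in this tensor. The first element is an extra element and is always 0. Empty Tensors are supported.
      Data type: INT32; data format: ND; shape: 1-D, fixed to (B+1,).
    cuSeqlensKOptional (input)
      Description: Effective Sequence Lengths of Key in each Batch, stored cumulatively.
      Usage notes: Required in the TND scenario, where the Batch value is derived from the number of elements in this tensor. The first element is an extra element and is always 0. Empty Tensors are supported.
      Data type: INT32; data format: ND; shape: 1-D, fixed to (B+1,).
    sequsedQOptional (input)
      Description: Sequence Length of Query in each Batch that actually participates in the computation.
      Usage notes: Empty Tensors are supported.
      Data type: INT32; data format: ND; shape: 1-D, fixed to (B,).
    sequsedKOptional (input)
      Description: Sequence Length of Key in each Batch that actually participates in the computation.
      Usage notes: Empty Tensors are supported.
      Data type: INT32; data format: ND; shape: 1-D, fixed to (B,).
    cmpResidualKOptional (input)
      Description: Reserved parameter. Remainder of the Key Sequence Length in each Batch; it currently has no effect on the computation result.
      Usage notes: Must be passed when cmpRatio is not 1 and maskMode is 3. Empty Tensors are supported.
      Data type: INT32; data format: ND; shape: 1-D, fixed to (B,).
    numHeadsQ (input)
      Description: Number of Query heads. Non-negative values are supported.
      Data type: INT64.
    numHeadsK (input)
      Description: Number of Key heads. Non-negative values are supported.
      Data type: INT64.
    headDim (input)
      Description: Head dimension, that is, the size of each attention head. Non-negative values are supported.
      Data type: INT64.
    topk (input)
      Description: Number of critical sparse tokens selected (top-k) for the Query. Non-negative values are supported.
      Data type: INT64.
    batchSize (input)
      Description: Number of Batches. Non-negative values are supported; the recommended value is 0.
      Data type: INT64.
    maxSeqlenQ (input)
      Description: Maximum Sequence Length of Query. Non-negative values are supported; the recommended value is 0.
      Data type: INT64.
    maxSeqlenK (input)
      Description: Maximum Sequence Length of Key. Non-negative values are supported; the recommended value is 0.
      Data type: INT64.
    layoutQOptional (input)
      Description: Data layout of Query. Supported values: BSND, TND; the recommended value is BSND.
      Data type: string (char *).
    layoutKOptional (input)
      Description: Data layout of Key. Supported values: BSND, TND, PA_BBND; the recommended value is BSND.
      Data type: string (char *).
    maskMode (input)
      Description: Sparse (mask) mode. 0: no mask. 3: rightDownCausal mask, that is, a lower-triangular mask bounded at the bottom-right vertex. The recommended value is 0.
      Data type: INT64.
    cmpRatio (input)
      Description: Reserved parameter. Compression ratio of Key; it currently has no effect on the computation result. The recommended value is 1, meaning no compression. Value range: [1, 128].
      Data type: INT64.
    metaData (output)
      Description: Load-balancing result.
      Data type: INT32; data format: ND; shape: 1-D, fixed to (1024,); non-contiguous Tensors are not supported.
    workspaceSize (output)
      Description: Returns the size of the workspace to be allocated on the Device side.
    executor (output)
      Description: Returns the op executor, which contains the operator computation flow.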
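Two of the inputs above can be prepared or checked on the host. The sketch below assumes that cuSeqlensQOptional/cuSeqlensKOptional hold cumulative (prefix-sum) lengths, which matches their (B+1,) shape and fixed leading 0, and that the rightDownCausal mask of maskMode 3 follows the common bottom-right-aligned causal convention. `BuildCuSeqlens` and `VisibleKeysRightDownCausal` are hypothetical helpers, not CANN APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: builds a cuSeqlens-style array of shape (B+1,) from
// per-batch sequence lengths. Element 0 is fixed to 0 and element b+1 is the
// running sum of lengths up to batch b (assumed prefix-sum layout).
std::vector<int32_t> BuildCuSeqlens(const std::vector<int32_t> &seqlens)
{
    std::vector<int32_t> cu(seqlens.size() + 1, 0);
    for (std::size_t b = 0; b < seqlens.size(); ++b) {
        assert(seqlens[b] >= 0);
        cu[b + 1] = cu[b] + seqlens[b];
    }
    return cu;
}

// Hypothetical helper: with a lower-triangular mask aligned to the bottom-right
// corner (assumed meaning of rightDownCausal), Query row q (0-based) of a batch
// with lengths seqQ/seqK sees Key positions j <= q + (seqK - seqQ).
// Returns the number of visible Key positions (0 if none).
int64_t VisibleKeysRightDownCausal(int64_t q, int64_t seqQ, int64_t seqK)
{
    assert(q >= 0 && q < seqQ);
    const int64_t lastVisible = q + (seqK - seqQ);  // last visible Key index
    return lastVisible < 0 ? 0 : lastVisible + 1;
}
```

For example, BuildCuSeqlens({3, 5, 2}) returns {0, 3, 8, 10}, and with seqQ = 2 and seqK = 5 the first Query row sees 4 Key positions.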
  • Return value:

    Returns an aclnnStatus status code. For details, see aclnn return codes.

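aclnnLightningIndexerV2Metadata

  • Parameters:

    Parameter      Input/Output  Description
    workspace      Input         Address of the workspace memory allocated on the Device side.
    workspaceSize  Input         Size of the workspace allocated on the Device side, obtained from the first-phase API aclnnLightningIndexerV2MetadataGetWorkspaceSize.
    executor       Input         Op executor, which contains the operator computation flow.
    stream         Input         Stream on which the task is executed.
  • Return value:

    Returns an aclnnStatus status code. For details, see aclnn return codes.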

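Constraints

  • aclnnLightningIndexerV2Metadata uses a deterministic implementation by default.

  • Batch resolution rules

    • Batch information is taken from sequsedQOptional first.
    • If sequsedQOptional is not passed, Batch information is taken from cuSeqlensQOptional.
    • If sequsedQOptional is not passed and layoutQOptional is TND, Batch information must be taken from cuSeqlensQOptional.
    • Otherwise, batchSize is used.
  • Sequence Length resolution rules

    • Sequence Length information is taken from sequsedQOptional first.
    • If sequsedQOptional is not passed and layoutQOptional is TND, Sequence Length information must be taken from cuSeqlensQOptional.
    • Otherwise, maxSeqlenQ is used.
    • Key follows the same rules as Query.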
  • Layout constraints

    • When layoutKOptional is PA_BBND, layoutQOptional can take any value.
    • Otherwise, layoutQOptional must be the same as layoutKOptional.
  • BSND scenario

    • When layoutQOptional is "BSND", the BSND scenario applies.
    • If neither cuSeqlensQOptional nor sequsedQOptional is passed, batchSize, maxSeqlenQ, and maxSeqlenK must be passed.
  • TND scenario

    • When layoutQOptional is "TND", the TND scenario applies.
    • cuSeqlensQOptional and cuSeqlensKOptional must be passed.
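The Batch, Sequence Length, and layout rules above can be modeled on the host as follows. This is a sketch that assumes an empty std::vector stands for an optional tensor that was not passed; the type and function names are hypothetical and not part of the CANN API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical host-side view of one side (Query or Key) of the inputs.
// An empty vector means the corresponding optional tensor was not passed.
struct SeqInputs {
    std::vector<int32_t> seqused;    // shape (B,)
    std::vector<int32_t> cuSeqlens;  // shape (B+1,), first element 0
    int64_t maxSeqlen = 0;
};

// Batch: sequsedQOptional first, then cuSeqlensQOptional, otherwise batchSize.
int64_t ResolveBatch(const SeqInputs &q, int64_t batchSize)
{
    if (!q.seqused.empty()) {
        return static_cast<int64_t>(q.seqused.size());
    }
    if (!q.cuSeqlens.empty()) {
        return static_cast<int64_t>(q.cuSeqlens.size()) - 1;
    }
    return batchSize;
}

// Sequence Length of batch b: seqused first; cuSeqlens difference only for TND;
// otherwise maxSeqlen. Key uses the same rules with its own inputs.
int64_t ResolveSeqlen(const SeqInputs &in, const std::string &layout, int64_t b)
{
    const std::size_t idx = static_cast<std::size_t>(b);
    if (!in.seqused.empty()) {
        return in.seqused.at(idx);
    }
    if (layout == "TND") {
        assert(!in.cuSeqlens.empty());  // required in the TND scenario
        return in.cuSeqlens.at(idx + 1) - in.cuSeqlens.at(idx);
    }
    return in.maxSeqlen;
}

// Layout: any Query layout is accepted with a PA_BBND Key; otherwise they must match.
bool LayoutsCompatible(const std::string &layoutQ, const std::string &layoutK)
{
    return layoutK == "PA_BBND" || layoutQ == layoutK;
}
```

For example, with cuSeqlens {0, 3, 8} and no seqused, the Batch resolves to 2 and the second batch has a Sequence Length of 5 under TND, while under BSND the same batch falls back to maxSeqlen.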

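Invocation Example

The sample code below is for reference only. For details about compilation and execution, see Compiling and Running Samples.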

#include <iostream>
#include <vector>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <functional>
#include <utility>
#include "acl/acl.h"
#include "aclnnop/aclnn_lightning_indexer_v2_metadata.h"

#define CHECK_LOG_RET(cond, ret_val, fmt, ...)      \
    do {                                            \
        if (!(cond)) {                              \
            printf(fmt "\n", ##__VA_ARGS__);        \
            return (ret_val);                       \
        }                                           \
    } while (0)

// Metadata layout constants; see lightning_indexer_v2_metadata.h
constexpr uint32_t AIC_CORE_NUM = 36;
constexpr uint32_t AIV_CORE_NUM = 72;
constexpr uint32_t LI_V2_METADATA_SIZE = 8;
constexpr uint32_t LD_V2_METADATA_SIZE = 8;

// LI Metadata Index Definitions
constexpr uint32_t LI_V2_CORE_ENABLE_INDEX = 0;
constexpr uint32_t LI_V2_BN2_START_INDEX = 1;
constexpr uint32_t LI_V2_M_START_INDEX = 2;
constexpr uint32_t LI_V2_S2_START_INDEX = 3;
constexpr uint32_t LI_V2_BN2_END_INDEX = 4;
constexpr uint32_t LI_V2_M_END_INDEX = 5;
constexpr uint32_t LI_V2_S2_END_INDEX = 6;
constexpr uint32_t LI_V2_FIRST_LD_V2_DATA_WORKSPACE_IDX_INDEX = 7;

// LD Metadata Index Definitions
constexpr uint32_t LD_V2_CORE_ENABLE_INDEX = 0;
constexpr uint32_t LD_V2_BN2_IDX_INDEX = 1;
constexpr uint32_t LD_V2_M_IDX_INDEX = 2;
constexpr uint32_t LD_V2_WORKSPACE_IDX_INDEX = 3;
constexpr uint32_t LD_V2_WORKSPACE_NUM_INDEX = 4;
constexpr uint32_t LD_V2_M_START_INDEX = 5;
constexpr uint32_t LD_V2_M_NUM_INDEX = 6;

struct LiV2MetaData {
    uint32_t faData[AIC_CORE_NUM][LI_V2_METADATA_SIZE];
    uint32_t fdData[AIV_CORE_NUM][LD_V2_METADATA_SIZE];
};

struct ScopeGuard
{
    explicit ScopeGuard(std::function<void()> onExitScope) : m_exitFunc(std::move(onExitScope)),
        m_isDismissed(false) {}
    // Non-copyable
    ScopeGuard(const ScopeGuard&) = delete;
    ScopeGuard& operator=(const ScopeGuard&) = delete;

    ~ScopeGuard()
    {
        if (!m_isDismissed) {
            m_exitFunc();
        }
    }

    void Dismiss()
    {
        m_isDismissed = true;
    }

    std::function<void()> m_exitFunc;
    bool m_isDismissed;
};

struct Tensor {
    void *hostAddr { nullptr };
    void *deviceAddr { nullptr };
    aclTensor *data { nullptr };
};

struct ArgScenario {
    bool hasCuSeq { false };
    bool hasSeqused { false };
};

struct ArgContext {
    // required input
    int64_t numHeadsQ { 0 };
    int64_t numHeadsK { 0 };
    int64_t headDim { 0 };
    int64_t topk { 0 };
    // optional input
    Tensor cuSeqlensQOptional {};
    Tensor cuSeqlensKOptional {};
    Tensor sequsedQOptional {};
    Tensor sequsedKOptional {};
    Tensor cmpResidualKOptional {};
    int64_t batchSize { 0 };
    int64_t maxSeqlenQ { 0 };
    int64_t maxSeqlenK { 0 };
    char *layoutQOptional { nullptr };
    char *layoutKOptional { nullptr };
    int64_t maskMode { 0 };
    int64_t cmpRatio { 0 };
    // output
    Tensor metaData {};
};

int64_t GetShapeSize(const std::vector<int64_t>& shape) 
{
    int64_t shapeSize = 1;
    for (auto i : shape) {
        shapeSize *= i;
    }
    return shapeSize;
}

aclnnStatus Init(int32_t deviceId, aclrtStream* stream) 
{
    // Boilerplate: initialize ACL, set the device, and create a stream
    auto ret = aclInit(nullptr);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclInit failed. ERROR: %d", ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtSetDevice failed. ERROR: %d", ret);
    ret = aclrtCreateStream(stream);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtCreateStream failed. ERROR: %d", ret);
    return ACL_SUCCESS;
}

void Finalize(int32_t deviceId, aclrtStream stream) 
{
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
}

aclnnStatus CreateTensor(aclDataType dataType, const std::vector<int64_t> &shape, Tensor &tensor)
{
    auto size = GetShapeSize(shape) * aclDataTypeSize(dataType);
    // Allocate host memory with aclrtMallocHost and zero-fill it
    auto ret = aclrtMallocHost(&(tensor.hostAddr), size);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMallocHost failed. ERROR: %d", ret);
    memset(tensor.hostAddr, 0, size);
    // Allocate device memory with aclrtMalloc
    ret = aclrtMalloc(&(tensor.deviceAddr), size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMalloc failed. ERROR: %d", ret);
    // Compute contiguous (row-major) strides for the tensor
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    // Create the aclTensor with aclCreateTensor
    tensor.data = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
        shape.data(), shape.size(), tensor.deviceAddr);
    CHECK_LOG_RET(tensor.data != nullptr, -1, "aclCreateTensor failed.");

    // Copy the host data to device memory with aclrtMemcpy
    ret = aclrtMemcpy(tensor.deviceAddr, size, tensor.hostAddr, size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMemcpy failed. ERROR: %d", ret);
    return ACL_SUCCESS;
}

void DestroyTensor(Tensor &tensor)
{
    if (tensor.data != nullptr) {
        aclDestroyTensor(tensor.data);
        tensor.data = nullptr;
    }
    if (tensor.deviceAddr != nullptr) {
        aclrtFree(tensor.deviceAddr);
        tensor.deviceAddr = nullptr;
    }
    if (tensor.hostAddr != nullptr) {
        aclrtFreeHost(tensor.hostAddr);
        tensor.hostAddr = nullptr;
    }
}

void DestroyArgs(ArgContext &context)
{
    DestroyTensor(context.metaData);
    DestroyTensor(context.cuSeqlensQOptional);
    DestroyTensor(context.cuSeqlensKOptional);
    DestroyTensor(context.sequsedQOptional);
    DestroyTensor(context.sequsedKOptional);
    DestroyTensor(context.cmpResidualKOptional);

    if (context.layoutQOptional != nullptr) {
        free(context.layoutQOptional);
        context.layoutQOptional = nullptr;
    }
    if (context.layoutKOptional != nullptr) {
        free(context.layoutKOptional);
        context.layoutKOptional = nullptr;
    }
}

aclnnStatus CreateArgs(const ArgScenario &scenario, ArgContext &context)
{
    ScopeGuard argsGuard([&] { DestroyArgs(context); });
    aclnnStatus ret;
    int64_t batchSize = 4;

    context.numHeadsQ = 1;
    context.numHeadsK = 1;
    context.headDim = 128;
    context.topk = 0;
    ret = CreateTensor(aclDataType::ACL_INT32, { 1024 }, context.metaData);     // 1024: Fix size
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create meta failed. Error: %d", ret);

    context.maskMode = 0;                   // 0: no mask, 3: causal
    context.cmpRatio = 1;                   // [1, 128], 1: no compress
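    context.layoutQOptional = static_cast<char *>(malloc(sizeof(char) * 16));
    context.layoutKOptional = static_cast<char *>(malloc(sizeof(char) * 16));
    CHECK_LOG_RET(context.layoutQOptional != nullptr && context.layoutKOptional != nullptr, -1,
        "Allocate layout strings failed.");
    strcpy(context.layoutQOptional, "BSND");                // BSND, TND
    strcpy(context.layoutKOptional, "BSND");                // BSND, TND, PA_BBND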

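    if (!scenario.hasCuSeq && !scenario.hasSeqused) {
        // BSND without cuSeqlens/seqused: batchSize, maxSeqlenQ, and maxSeqlenK are required
        context.batchSize = batchSize;
        context.maxSeqlenK = 1024;
        context.maxSeqlenQ = 1024;
        argsGuard.Dismiss();    // keep the created arguments; the caller releases them
        return ACL_SUCCESS;
    }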

    if (scenario.hasCuSeq) {
        // (B+1,), first element is always 0
        ret = CreateTensor(aclDataType::ACL_INT32, { batchSize + 1 }, context.cuSeqlensQOptional);
        CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create cuSeqlensQOptional failed. Error: %d", ret);
        ret = CreateTensor(aclDataType::ACL_INT32, { batchSize + 1 }, context.cuSeqlensKOptional);
        CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create cuSeqlensKOptional failed. Error: %d", ret);
    }

    if (scenario.hasSeqused) {
        // (B,)
        ret = CreateTensor(aclDataType::ACL_INT32, { batchSize }, context.sequsedQOptional);
        CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create sequsedQOptional failed. Error: %d", ret);
        ret = CreateTensor(aclDataType::ACL_INT32, { batchSize }, context.sequsedKOptional);
        CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create sequsedKOptional failed. Error: %d", ret);
    }

    argsGuard.Dismiss();
    return ACL_SUCCESS;
}

int main() {
    // 1. (Boilerplate) Initialize the device and stream; see the external API list
    // Set deviceId to match your actual device
    int32_t deviceId = 0;
    aclrtStream stream;
    auto ret = Init(deviceId, &stream);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Init acl failed. ERROR: %d", ret);
    ScopeGuard sysGuard([&] { Finalize(deviceId, stream); });

    // 2. Construct inputs and outputs according to the API definition
    ArgScenario scenario {};
    scenario.hasCuSeq = true;
    scenario.hasSeqused = true;
    ArgContext context {};
    ret = CreateArgs(scenario, context);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "Create input arguments failed. ERROR: %d", ret);
    ScopeGuard argsGuard([&] { DestroyArgs(context); });

    // 3. Call the CANN operator library API (change this to the specific API you need)
    // Call the first-phase API of aclnnLightningIndexerV2Metadata
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;
    void *workspaceAddr = nullptr;
    ret = aclnnLightningIndexerV2MetadataGetWorkspaceSize(
        context.cuSeqlensQOptional.data, context.cuSeqlensKOptional.data, context.sequsedQOptional.data,
        context.sequsedKOptional.data, context.cmpResidualKOptional.data,
        context.numHeadsQ, context.numHeadsK, context.headDim, context.topk,
        context.batchSize, context.maxSeqlenQ, context.maxSeqlenK, context.layoutQOptional,
        context.layoutKOptional, context.maskMode, context.cmpRatio,
        context.metaData.data, &workspaceSize, &executor);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclnnLightningIndexerV2MetadataGetWorkspaceSize failed. ERROR: %d", ret);

    if (workspaceSize > static_cast<uint64_t>(0)) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "allocate workspace failed. ERROR: %d", ret);
    }
    ScopeGuard workspaceGuard([&] {
        if (workspaceAddr != nullptr) {
            aclrtFree(workspaceAddr);
            workspaceAddr = nullptr;
        }
    });
    
    // Call the second-phase API of aclnnLightningIndexerV2Metadata
    ret = aclnnLightningIndexerV2Metadata(workspaceAddr, workspaceSize, executor, stream);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclnnLightningIndexerV2Metadata failed. ERROR: %d", ret);

    // 4. (Boilerplate) Wait synchronously for the task to finish
    ret = aclrtSynchronizeStream(stream);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtSynchronizeStream failed. ERROR: %d", ret);

    // 5. Copy the metaData output back to the host and print it
    LiV2MetaData result {};
    ret = aclrtMemcpy(&result, sizeof(result), context.metaData.deviceAddr, sizeof(result), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_LOG_RET(ret == ACL_SUCCESS, ret, "aclrtMemcpy failed. ERROR: %d", ret);

    for (uint32_t i = 0; i < AIC_CORE_NUM; ++i) {
        printf("AIC Core%u\n", i);
        printf("    Core Enable : %u\n", result.faData[i][LI_V2_CORE_ENABLE_INDEX]);
        printf("    Start BN2   : %u\n", result.faData[i][LI_V2_BN2_START_INDEX]);
        printf("    Start M     : %u\n", result.faData[i][LI_V2_M_START_INDEX]);
        printf("    Start S2    : %u\n", result.faData[i][LI_V2_S2_START_INDEX]);
        printf("    End BN2     : %u\n", result.faData[i][LI_V2_BN2_END_INDEX]);
        printf("    End M       : %u\n", result.faData[i][LI_V2_M_END_INDEX]);
        printf("    End S2      : %u\n", result.faData[i][LI_V2_S2_END_INDEX]);
        printf("    First Workspace Index : %u\n", result.faData[i][LI_V2_FIRST_LD_V2_DATA_WORKSPACE_IDX_INDEX]);
    }
    for (uint32_t i = 0; i < AIV_CORE_NUM; ++i) {
        printf("AIV Core%u\n", i);
        printf("    Core Enable             : %u\n", result.fdData[i][LD_V2_CORE_ENABLE_INDEX]);
        printf("    FD Task BN2 Idx         : %u\n", result.fdData[i][LD_V2_BN2_IDX_INDEX]);
        printf("    FD Task M Idx           : %u\n", result.fdData[i][LD_V2_M_IDX_INDEX]);
        printf("    FD Task Workspace Idx   : %u\n", result.fdData[i][LD_V2_WORKSPACE_IDX_INDEX]);
        printf("    FD Task Workspace Num   : %u\n", result.fdData[i][LD_V2_WORKSPACE_NUM_INDEX]);
        printf("    FD Subtask M Start      : %u\n", result.fdData[i][LD_V2_M_START_INDEX]);
        printf("    FD Subtask M Num        : %u\n", result.fdData[i][LD_V2_M_NUM_INDEX]);
    }

    return 0;
}
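The example reads metaData by copying it into LiV2MetaData. Because faData is immediately followed by fdData in that struct, each field can also be located in a flat host buffer of 1024 INT32 values. The sketch below assumes the core and field counts used in the example (36 AIC cores, 72 AIV cores, 8 fields each), which may differ on other hardware; `AicOffset` and `AivOffset` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdint>

// Counts copied from the example above; assumed, and may differ per hardware.
constexpr uint32_t kAicCoreNum = 36;       // AIC_CORE_NUM
constexpr uint32_t kAivCoreNum = 72;       // AIV_CORE_NUM
constexpr uint32_t kLiFieldNum = 8;        // LI_V2_METADATA_SIZE
constexpr uint32_t kLdFieldNum = 8;        // LD_V2_METADATA_SIZE
constexpr uint32_t kMetaDataElems = 1024;  // fixed shape of metaData

// faData and fdData together must fit within the 1024-element INT32 output.
static_assert(kAicCoreNum * kLiFieldNum + kAivCoreNum * kLdFieldNum <= kMetaDataElems,
              "LiV2MetaData must fit in metaData");

// Flat INT32 offset of `field` for AIC core `core` (faData starts at offset 0).
uint32_t AicOffset(uint32_t core, uint32_t field)
{
    assert(core < kAicCoreNum && field < kLiFieldNum);
    return core * kLiFieldNum + field;
}

// Flat INT32 offset of `field` for AIV core `core` (fdData starts right after faData).
uint32_t AivOffset(uint32_t core, uint32_t field)
{
    assert(core < kAivCoreNum && field < kLdFieldNum);
    return kAicCoreNum * kLiFieldNum + core * kLdFieldNum + field;
}
```

For instance, with a host buffer `meta` copied from metaData, meta[AivOffset(i, LD_V2_WORKSPACE_NUM_INDEX)] refers to the same value as result.fdData[i][LD_V2_WORKSPACE_NUM_INDEX] in the example.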