aclnnRecurrentGatedDeltaRule

📄 查看源码

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 接口功能:完成变步长的Recurrent Gated Delta Rule计算。

  • 计算公式:

    Recurrent Gated Delta Rule(循环门控Delta规则,RGDR)是一种应用于循环神经网络的算子,也被应用于一种线性注意力机制中。 在每个时间步 tt,网络根据当前的输入 qtq_tktk_tvtv_t 和上一个隐藏状态 St−1S_{t-1},计算当前的注意力输出 oto_t 和新的隐藏状态 StS_t。 在这个过程中,门控单元会决定有多少新信息存入隐藏状态,以及有多少旧信息需要被遗忘。

    St:=St−1(αtDiag(αkt)(I−βtktktT))+βtvtktT=αtDiag(αkt)St−1+βt(vt−αtDiag(αkt)St−1kt)ktTS_t := S_{t-1}(\alpha_t Diag(\alpha_{kt})(I - \beta_t k_t k_t^T)) + \beta_t v_t k_t^T = \alpha_t Diag(\alpha_{kt})S_{t-1} + \beta_t (v_t - \alpha_t Diag(\alpha_{kt})S_{t-1}k_t)k_t^T

    o:=Stqtdko := \frac{S_t q_t}{\sqrt{d_k}}

    其中,St−1,St∈Rdv×dkS_{t-1},S_t \in R^{d_v \times d_k}qt,kt∈Rdkq_t, k_t \in R^{d_k}vt∈Rdvv_t \in R^{d_v}αt∈R\alpha_t \in Rαkt∈Rdk\alpha_{kt} \in R^{d_k}βt∈R\beta_t \in Ro∈Rdvo \in R^{d_v}

函数原型

每个算子分为两段式接口,必须先调用“aclnnRecurrentGatedDeltaRuleGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnRecurrentGatedDeltaRule”接口执行计算。

aclnnStatus aclnnRecurrentGatedDeltaRuleGetWorkspaceSize(
    const aclTensor *query,
    const aclTensor *key,
    const aclTensor *value,
    const aclTensor *beta,
    aclTensor       *stateRef,
    const aclTensor *actualSeqLengths,
    const aclTensor *ssmStateIndices,
    const aclTensor *g,
    const aclTensor *gk,
    const aclTensor *numAcceptedTokens,
    float           scaleValue,
    aclTensor       *out,
    uint64_t        *workspaceSize,
    aclOpExecutor   **executor)
aclnnStatus aclnnRecurrentGatedDeltaRule(
    void          *workspace,
    uint64_t      workspaceSize,
    aclOpExecutor *executor,
    aclrtStream   stream)

aclnnRecurrentGatedDeltaRuleGetWorkspaceSize

  • 参数说明

    参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor
    query 输入 公式中的q。 不支持空Tensor。 BFLOAT16 ND (T, Nk, Dk)
    key 输入 公式中的k。 不支持空Tensor。 BFLOAT16 ND (T, Nk, Dk)
    value 输入 公式中的v。 不支持空Tensor。 BFLOAT16 ND (T, Nv, Dv)
    beta 输入 公式中的β。 不支持空Tensor。 BFLOAT16 ND (T, Nv)
    stateRef 输入&输出 状态矩阵,公式中的S。 不支持空Tensor。 BFLOAT16 ND (BlockNum, Nv, Dv, Dk) ×
    actualSeqLengths 输入 不同batch的有效序列长度。 不支持空Tensor。 INT32 ND (B,)
    ssmStateIndices 输入 输入序列到状态矩阵的映射索引。
    • 不支持空Tensor。
    • state[ssmStateIndices[i]]表示第i个token的状态矩阵。
    INT32 ND (T,)
    g 输入 衰减系数,公式中的α=e^g。
    • 不支持空Tensor。
    • 如果传入nullptr,则表示全0的tensor。
    FLOAT32 ND (T, Nv)
    gk 输入 衰减系数,公式中的αk=e^gk
    • 不支持空Tensor。
    • 如果传入nullptr,则表示全0的tensor。
    FLOAT32 ND (T, Nv, Dk)
    numAcceptedTokens 输入 每个序列接受的token数量。 不支持空Tensor。 INT32 ND (B,)
    scaleValue 输入 query的缩放因子,对应公式中的 1/sqrt(d_k)。 - - - - -
    out 输出 公式中的o。 - BFLOAT16 ND (T, Nv, Dv)
    workspaceSize 输出 返回需要在Device侧申请的workspace大小。 - - - - -

    其中 BB 表示batch size,令 LiL_i 表示第i个序列的长度,则 T=∑iBLiT=\sum_i^B L_i 表示累积序列长度。NkN_k 表示key的头数,NvN_v 表示value的头数,DkD_k 表示key向量的维度,DvD_v 表示value向量的维度。

  • 返回值

    aclnnStatus: 返回状态码,具体参见aclnn返回码

    第一段接口完成入参校验,出现以下场景时报错:

    返回值 错误码 描述
    ACLNN_ERR_PARAM_NULLPTR 161001 query, key, value, beta, stateRef, actualSeqLengths, ssmStateIndices, numAcceptedTokens, out存在空指针。
    ACLNN_ERR_PARAM_INVALID 161002 输入Tensor的数据类型不在支持的范围内。
    输入Tensor的数据格式不在支持范围内。
    输入Tensor的shape不在支持范围内。

aclnnRecurrentGatedDeltaRule

  • 参数说明

    参数名 输入/输出 描述
    workspace 输入 在Device侧申请的workspace内存地址。
    workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnRecurrentGatedDeltaRuleGetWorkspaceSize获取。
    executor 输入 op执行器,包含算子计算流程。
    stream 输入 指定执行任务的Stream。
  • 返回值

    aclnnStatus: 返回状态码,具体参见aclnn返回码

约束说明

  • 确定性计算:
    • aclnnRecurrentGatedDeltaRule默认确定性实现。
  • 输入shape大小需满足约束:0<Li≤80 < L_i \le 80<Nk≤2560 < N_k \le 256Nk≤Nv≤256N_k \le N_v \le 256NvN_v % Nk==0N_k == 00<Dk≤5120 < D_k \le 5120<Dv≤5120 < D_v \le 5120<T0 < T0<B0 < BT≤BlockNumT \le BlockNum
  • 以下约束由于算子无法获取tensor中具体数值,故需用户保证,算子不校验:
    • ssmStateIndices[i]<BlockNumssmStateIndices[i] < BlockNum
    • 0<actualSeqLengths[i]≤80 < actualSeqLengths[i] \le 8,且actualSeqLengths[i]actualSeqLengths[i]累加和等于TT
    • 1≤numAcceptedTokens[i]≤actualSeqLengths[i]1 \le numAcceptedTokens[i] \le actualSeqLengths[i]
    • −1≤query[i][j][k]≤1-1 \le query[i][j][k] \le 1
    • −1≤key[i][j][k]≤1-1 \le key[i][j][k] \le 1
    • g[i][j]<0g[i][j] < 0
    • gk[i][j][k]<0gk[i][j][k] < 0
    • 0<beta[i][j]<10 < beta[i][j] < 1

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例

#include <iostream>
#include <vector>
#include "acl/acl.h"
#include "aclnnop/aclnn_recurrent_gated_delta_rule.h"

#define CHECK_RET(cond, return_expr)                                                                                   \
    do {                                                                                                               \
        if (!(cond)) {                                                                                                 \
            return_expr;                                                                                               \
        }                                                                                                              \
    } while (0)

#define LOG_PRINT(message, ...)                                                                                        \
    do {                                                                                                               \
        printf(message, ##__VA_ARGS__);                                                                                \
    } while (0)

int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
    int64_t shapeSize = 1;
    for (auto i : shape) {
        shapeSize *= i;
    }
    return shapeSize;
}

void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr)
{
    auto size = GetShapeSize(shape);
    std::vector<aclFloat16> resultData(size, 0);
    auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), *deviceAddr,
                           size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
    for (int64_t i = 0; i < size; i++) {
        if (i >= 5) { // print the first five data
            break;
        }
        LOG_PRINT("mean result[%ld] is: %f\n", i, aclFloat16ToFloat(resultData[i]));
    }
}

int Init(int32_t deviceId, aclrtContext *context, aclrtStream *stream)
{
    // AscendCL初始化
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateContext(context, deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetCurrentContext(*context);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
}

template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
                    aclDataType dataType, aclTensor **tensor)
{
    auto size = GetShapeSize(shape) * sizeof(T);
    // 调用aclrtMalloc申请device侧内存
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);

    // 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);

    // 调用aclCreateTensor接口创建aclTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, nullptr, 0, aclFormat::ACL_FORMAT_ND, shape.data(),
                              shape.size(), *deviceAddr);
    return ACL_SUCCESS;
}

int main()
{
    // 1.device/context/stream初始化,参考AscendCL对外接口列表
    // 根据自己的实际device填写deviceId
    int32_t deviceId = 0;
    aclrtContext context;
    aclrtStream stream;
    auto ret = Init(deviceId, &context, &stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

    // 2. 构造输入与输出,需要根据API的接口自定义构造
    void *queryDeviceAddr = nullptr;
    void *keyDeviceAddr = nullptr;
    void *valueDeviceAddr = nullptr;
    void *gamaDeviceAddr = nullptr;
    void *betaDeviceAddr = nullptr;
    void *stateRefDeviceAddr = nullptr;
    void *actSeqLenDeviceAddr = nullptr;
    void *ssmStaIdDeviceAddr = nullptr;
    void *numAccTokDeviceAddr = nullptr;
    void *attnOutDeviceAddr = nullptr;

    aclTensor *query = nullptr;
    aclTensor *key = nullptr;
    aclTensor *value = nullptr;
    aclTensor *gama = nullptr;
    aclTensor *gamak = nullptr;
    aclTensor *beta = nullptr;
    aclTensor *stateRef = nullptr;
    aclTensor *actSeqLen = nullptr;
    aclTensor *ssmStaId = nullptr;
    aclTensor *numAccTok = nullptr;
    aclTensor *attnOut = nullptr;

    // 自定义输入与属性
    int32_t batchSize = 2;
    int32_t mtp = 2;
    int32_t headKNum = 4;
    int32_t headVNum = 8;
    int32_t dimV = 32;
    int32_t dimK = 32;


    std::vector<int64_t> stateShape = {batchSize * mtp, headVNum, dimV, dimK};
    std::vector<int64_t> qkShape = {batchSize * mtp, headKNum, dimK};
    std::vector<int64_t> vShape = {batchSize * mtp, headVNum, dimV};
    std::vector<int64_t> gamaShape = {batchSize * mtp, headVNum};
    std::vector<int64_t> actSeqLenShape = {batchSize};
    std::vector<int64_t> ssmStaIdShape = {batchSize * mtp};
    std::vector<float> stateRefHostData(GetShapeSize(stateShape));
    std::vector<float> queryHostData(GetShapeSize(qkShape));
    std::vector<float> keyHostData(GetShapeSize(qkShape));
    std::vector<float> valueHostData(GetShapeSize(vShape));
    std::vector<float> gamaHostData(GetShapeSize(gamaShape));
    std::vector<float> betaHostData(GetShapeSize(gamaShape));
    std::vector<int32_t> actSeqLenHostData(batchSize, mtp);
    std::vector<int32_t> ssmStaIdHostData(batchSize * mtp);
    std::vector<int32_t> numAccTokHostData(batchSize, 1);
    for (int i = 0; i < stateRefHostData.size(); i++) {
        stateRefHostData[i] = 0.5;
    }
    for (int i = 0; i < queryHostData.size(); i++) {
        queryHostData[i] = 0.5;
    }
    for (int i = 0; i < keyHostData.size(); i++) {
        keyHostData[i] = 0.5;
    }
    for (int i = 0; i < valueHostData.size(); i++) {
        valueHostData[i] = 0.5;
    }
    for (int i = 0; i < betaHostData.size(); i++) {
        betaHostData[i] = 0.5;
    }
    for (int i = 0; i < ssmStaIdHostData.size(); i++) {
        ssmStaIdHostData[i] = i;
    }

    std::vector<float> attnOutHostData(valueHostData);

    ret = CreateAclTensor(stateRefHostData, stateShape, &stateRefDeviceAddr, aclDataType::ACL_BF16, &stateRef);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(queryHostData, qkShape, &queryDeviceAddr, aclDataType::ACL_BF16, &query);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(keyHostData, qkShape, &keyDeviceAddr, aclDataType::ACL_BF16, &key);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(valueHostData, vShape, &valueDeviceAddr, aclDataType::ACL_BF16, &value);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gamaHostData, gamaShape, &gamaDeviceAddr, aclDataType::ACL_FLOAT, &gama);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(betaHostData, gamaShape, &betaDeviceAddr, aclDataType::ACL_BF16, &beta);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(actSeqLenHostData, actSeqLenShape, &actSeqLenDeviceAddr, aclDataType::ACL_INT32, &actSeqLen);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(ssmStaIdHostData, ssmStaIdShape, &ssmStaIdDeviceAddr, aclDataType::ACL_INT32, &ssmStaId);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(numAccTokHostData, actSeqLenShape, &numAccTokDeviceAddr, aclDataType::ACL_INT32, &numAccTok);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(attnOutHostData, vShape, &attnOutDeviceAddr, aclDataType::ACL_BF16, &attnOut);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    // 3. 调用CANN算子库API,需要修改为具体的Api名称
    uint64_t workspaceSize = 0;
    float scale = 1.0;
    aclOpExecutor *executor;
    // 调用aclnnRecurrentGatedDeltaRuleGetWorkspaceSize第一段接口
    ret = aclnnRecurrentGatedDeltaRuleGetWorkspaceSize(query, key, value, beta, stateRef, actSeqLen, ssmStaId, gama,
                                                       gamak, numAccTok, scale, attnOut, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnRecurrentGatedDeltaRuleGetWorkspaceSize failed. ERROR: %d\n", ret);
              return ret);

    // 根据第一段接口计算出的workspaceSize申请device内存
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
    }

    // 调用aclnnRecurrentGatedDeltaRule第二段接口
    ret = aclnnRecurrentGatedDeltaRule(workspaceAddr, workspaceSize, executor, stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnRecurrentGatedDeltaRule failed. ERROR: %d\n", ret); return ret);

    // 4. 同步等待任务执行结束
    ret = aclrtSynchronizeStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    // 5. 获取输出的值,将device侧内存上的结果拷贝至host侧,需要根据具体API的接口定义修改
    PrintOutResult(stateShape, &stateRefDeviceAddr);
    PrintOutResult(vShape, &attnOutDeviceAddr);

    // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
    aclDestroyTensor(query);
    aclDestroyTensor(key);
    aclDestroyTensor(value);
    aclDestroyTensor(gama);
    aclDestroyTensor(beta);
    aclDestroyTensor(stateRef);
    aclDestroyTensor(actSeqLen);
    aclDestroyTensor(ssmStaId);
    aclDestroyTensor(numAccTok);
    aclDestroyTensor(attnOut);

    // 7. 释放device资源
    aclrtFree(query);
    aclrtFree(key);
    aclrtFree(value);
    aclrtFree(gama);
    aclrtFree(beta);
    aclrtFree(stateRef);
    aclrtFree(actSeqLen);
    aclrtFree(ssmStaId);
    aclrtFree(numAccTok);
    aclrtFree(attnOut);
    if (workspaceSize > 0) {
        aclrtFree(workspaceAddr);
    }
    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();

    return 0;
}