aclnnMhcPreBackward

📄 查看源码

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT √
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 接口功能:MhcPreBackward是MhcPre的反向算子,MhcPre算子基于一系列计算得到mHC(Manifold-Constrained Hyper-Connections)架构中的$H^{res}$和$H^{post}$投影矩阵以及Atten或MLP层的输入矩阵$h^{in}$。

  • 计算公式:

    • 输出组合梯度计算

      • 正向公式:

        $$H\_in = \sum_{i=1}^{N} x[B,S,i,:] \cdot H\_pre_n[B,S,i]$$

      • 反向计算:

        $$\begin{aligned} H\_pre\_grad &= \text{Reduce}\left(H\_in\_grad.\text{unsqueeze}(-2) \odot x, \text{dim}=-1\right) \quad ([B,S,N]) \\ x\_grad\_vec3 &= H\_in\_grad \times H\_pre \quad ([B,S,N,D]) \end{aligned}$$

    • Sigmoid门控反向(H_pre)

      • 正向公式:

      $$H\_pre = \text{Sigmoid}(\alpha\_pre * H\_pre\_1 + bias\_pre) + hc\_eps$$

      • 反向计算:

        $$\begin{aligned} s &= H\_pre - hc\_eps \\ H\_pre\_2\_grad &= H\_pre\_grad \odot s \odot (1 - s) \\ H\_pre\_1\_grad &= H\_pre\_2\_grad \cdot \alpha\_pre \\ \alpha\_pre\_grad &= \sum_{b,s,n}^{B,S,N} \left(H\_pre\_2\_grad \cdot H\_pre\_1\right) \\ bias\_pre\_grad &= \sum_{b,s}^{B,S} H\_pre\_2\_grad \quad ([N]) \end{aligned}$$

    • Sigmoid门控反向(H_post)

      • 正向公式:

      $$H\_post = \text{Sigmoid}(\alpha\_post * H\_post\_1 + bias\_post) * 2$$

      • 反向计算:

        $$\begin{aligned} H\_post\_2\_grad &= H\_post\_grad \odot \left(H\_post \cdot \left(1 - \frac{H\_post}{2}\right)\right) \\ H\_post\_1\_grad &= H\_post\_2\_grad \cdot \alpha\_post \\ \alpha\_post\_grad &= \sum_{b,s,n}^{B,S,N} \left(H\_post\_2\_grad \cdot H\_post\_1\right) \\ bias\_post\_grad &= \sum_{b,s}^{B,S} H\_post\_2\_grad \quad ([N]) \end{aligned}$$

    • 残差连接反向(H_res)

      • 正向公式:

      $$H\_res = \alpha\_res * H\_res\_1 + bias\_res$$

      • 反向计算:

        $$\begin{aligned} H\_res\_2\_grad &= H\_res\_grad \cdot \alpha\_res \quad ([B,S,N,N]) \\ \alpha\_res\_grad &= \sum_{b,s,i,j}^{B,S,N,N} \left(H\_res\_grad \cdot H\_res\_2\right) \\ bias\_res\_grad &= \sum_{b,s}^{B,S} H\_res\_grad \quad ([N,N]) \\ H\_res\_1\_grad &= \text{Reshape}(H\_res\_2\_grad) \quad ([B,S,N^2]) \end{aligned}$$

    • RMSNorm Fusion反向

      • 正向公式:

      $$H\_mix\_tmp = H\_mix * inv\_rms$$

      • 反向计算:

        $$\begin{aligned} H\_mix\_tmp\_grad &= \text{Concat}(H\_pre\_1\_grad, H\_post\_1\_grad, H\_res\_1\_grad) \quad ([B,S,2N+N^2]) \\ H\_mix\_grad &= H\_mix\_tmp\_grad \cdot inv\_rms \\ inv\_rms\_grad &= \sum_{\text{last\_dim}} \left(H\_mix\_tmp\_grad \cdot H\_mix\right) \quad ([B,S,1]) \end{aligned}$$

    • 矩阵乘法反向

      • 正向公式:

      $$H\_mix = x\_rs @ phi^T$$

      $$x\_rs = x * gamma$$

      • 反向计算:

        $$\begin{aligned} x\_rs\_grad &= H\_mix\_grad @ phi \quad ([B,S,ND]) \\ X &= \text{Reshape}(x\_rs, [B\cdot S, ND]) \\ G &= \text{Reshape}(H\_mix\_grad, [B\cdot S, 2N+N^2]) \\ phi\_grad &= G^T @ X \quad ([2N+N^2, ND]) \end{aligned}$$

    • 特征缩放反向

      • 正向公式:

      $$x\_rs = x * gamma$$

      • 反向计算:

        $$\begin{aligned} x\_grad\_mm &= x\_rs\_grad * gamma \\ gamma\_grad &= \sum_{b=1}^{B}\sum_{s=1}^{S} (x * x\_rs\_grad) \quad ([N,D]) \end{aligned}$$

    • RMS归一化梯度计算

      • 正向公式:

        $$inv\_rms = \frac{1}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2 + eps}}, \quad \text{其中}\ n = N * D$$

      • 反向计算:

        $$\begin{aligned} x\_rs\_grad\_inv &= - \left(\frac{inv\_rms\_grad \cdot {inv\_rms}^3}{N*D}\right) \cdot x\_rs \\ x\_rs\_grad &= x\_grad\_mm + x\_rs\_grad\_inv \\ x\_grad\_vec1 &= \text{Reshape}(x\_rs\_grad, [B,S,N,D]) \\ x\_grad &= x\_grad\_vec3 + x\_grad\_vec1 \end{aligned}$$

    • 融合mhc_post的grad_x相加操作

      $$x\_grad = x\_grad + grad\_x\_post$$

函数原型

每个算子分为两段式接口,必须先调用“aclnnMhcPreBackwardGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnMhcPreBackward”接口执行计算。

aclnnStatus aclnnMhcPreBackwardGetWorkspaceSize(  
    const aclTensor     *x,
    const aclTensor     *phi,
    const aclTensor     *alpha,
    const aclTensor     *gradHIn,
    const aclTensor     *gradHPost,
    const aclTensor     *gradHRes,
    const aclTensor     *invRms,
    const aclTensor     *hMix,
    const aclTensor     *hPre,
    const aclTensor     *hPost,
    const aclTensor     *gamma,
    const aclTensor     *gradXPostOptional,
    float               hcEps,
    const aclTensor     *gradX,
    const aclTensor     *gradPhi,
    const aclTensor     *gradAlpha,
    const aclTensor     *gradBias,
    const aclTensor     *gradGamma,
    uint64_t            *workspaceSize,
    aclOpExecutor       **executor)
aclnnStatus aclnnMhcPreBackward(
    void             *workspace, 
    uint64_t          workspaceSize, 
    aclOpExecutor    *executor, 
    aclrtStream       stream)

aclnnMhcPreBackwardGetWorkspaceSize

  • 参数说明:

    参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor
    x 输入 待计算数据,表示网络中mHC层的输入数据。
  • 不支持空Tensor。
  • BFLOAT16、FLOAT16 ND (B,S,N,D)、(T,N,D)
    B:支持泛化;S:支持泛化;T:B*S。
    phi 输入 mHC的参数矩阵。
  • 不支持空Tensor。
  • FLOAT32 ND (2N+N*N,N*D)
    N:与x的N保持一致;D:与x的D保持一致
    alpha 输入 mHC的缩放参数alpha。
  • 不支持空Tensor。
  • FLOAT32 - (3) -
    gradHIn 输入 正向输出hIn对应的梯度,其中hIn为Atten/MLP层的输入。
  • 不支持空Tensor。
  • BFLOAT16、FLOAT16 ND (B,S,D)、(T,D)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;D:与x的D保持一致。
    gradHPost 输入 正向输出hPost对应的梯度。
  • 不支持空Tensor。
  • FLOAT32 ND (B,S,N)、(T,N)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。
    gradHRes 输入 正向输出hRes对应的梯度。
  • 不支持空Tensor。
  • FLOAT32 ND (B,S,N,N)、(T,N,N)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。
    invRms 输入 正向RmsNorm计算的invRms。
  • 不支持空Tensor。
  • FLOAT32 ND (B,S)、(T)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S。
    hMix 输入 正向计算流x@phi的结果
  • 不支持空Tensor。
  • FLOAT32 ND (B,S,2N+N*N)、(T,2N+N*N)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。
    hPre 输入 正向sigmoid计算之后的hPre矩阵
  • 不支持空Tensor。
  • FLOAT32 ND (B,S,N)、(T,N)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。
    hPost 输入 正向的hPost输出
  • 不支持空Tensor。
  • FLOAT32 ND (B,S,N)、(T,N)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。
    gamma 可选输入 RmsNorm的缩放系数gamma
  • 不支持空Tensor。
  • 如果传入nullptr,则表示全1的tensor。
  • FLOAT32 ND (N,D)
    N:与x的N保持一致;D:与x的D保持一致。
    gradXPostOptional 可选输入 post反向输出的gradX
  • 不支持空Tensor。
  • 如果传入nullptr,则表示全0的tensor。
  • BFLOAT16、FLOAT16 ND (B,S,N,D)、(T,N,D)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致;D:与x的D保持一致。
    hcEps 可选输入 HPre的sigmoid后的eps参数
  • 建议值为1e-6
  • FLOAT32 - - -
    gradX 输出 x对应的梯度。
  • 与输入x的维度、数据类型保持一致
  • BFLOAT16、FLOAT16 ND (B,S,N,D)、(T,N,D)
    B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致;D:与x的D保持一致。
    gradPhi 输出 phi对应的梯度。
  • 与输入phi的维度保持一致
  • FLOAT32 ND (2N+N*N,N*D)
    N:与x的N保持一致;D:与x的D保持一致。
    gradAlpha 输出 alpha对应的梯度。
  • 与输入alpha维度保持一致。
  • FLOAT32 ND (3) -
    gradBias 输出 bias对应的梯度。 - FLOAT32 ND (2N+N*N)
    N:与x的N保持一致。
    -
    gradGamma 可选输出 gamma对应的梯度。
  • 当输入gamma不为nullptr时,此变量才会输出。
  • 与输入gamma的Shape维度保持一致。
  • FLOAT32 ND (N,D)
    N:与x的N保持一致;D:与x的D保持一致。
    workspaceSize 输出 返回需要在Device侧申请的workspace大小。 - - - - -
    executor 输出 返回op执行器,包含了算子计算流程。 - - - - -
  • 返回值

    返回aclnnStatus状态码,具体参见aclnn返回码

    第一段接口完成入参校验,出现以下场景时报错:

    返回值 错误码 描述
    ACLNN_ERR_PARAM_NULLPTR 161001 必选参数或者输出是空指针。
    ACLNN_ERR_PARAM_INVALID 161002 输入变量,x、phi、gamma、alpha的数据类型和数据格式不在支持的范围内。
    ACLNN_ERR_RUNTIME_ERROR 361001 API内部调用npu runtime的接口异常。

aclnnMhcPreBackward

  • 参数说明:

    参数名 输入/输出 描述
    workspace 输入 在Device侧申请的workspace内存地址。
    workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnMhcPreBackwardGetWorkspaceSize获取。
    executor 输入 op执行器,包含了算子计算流程。
    stream 输入 指定执行任务的Stream流。
  • 返回值

    返回aclnnStatus状态码,具体参见aclnn返回码

约束说明

  • 确定性计算:

    • aclnnMhcPreBackward默认采用确定性实现。
  • 规格约束

    规格项 规格 规格说明
    N 4、6、8 N值目前支持4、6、8
    D 1~16384 D支持1~16384范围以内且64元素对齐

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <numeric>
#include <vector>
#include "acl/acl.h"
#include "aclnnop/aclnn_mhc_pre_backward.h"

// Evaluates `cond`; when it is false, executes `return_expr` (in this file
// usually a LOG_PRINT call followed by a `return` statement). The
// do { ... } while (0) wrapper makes the expansion behave as one statement.
#define CHECK_RET(cond, return_expr)                                                                                   \
    do {                                                                                                               \
        if (!(cond)) {                                                                                                 \
            return_expr;                                                                                               \
        }                                                                                                              \
    } while (0)

// Forwards `message` and any additional arguments to printf. `##__VA_ARGS__`
// removes the trailing comma when no extra arguments are supplied (a GNU
// preprocessor extension).
#define LOG_PRINT(message, ...)                                                                                        \
    do {                                                                                                               \
        printf(message, ##__VA_ARGS__);                                                                                \
    } while (0)

// Returns the number of elements described by `shape`, i.e. the product of
// all of its dimensions. An empty shape yields 1.
int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
    return std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                           [](int64_t product, int64_t dim) { return product * dim; });
}

// Copies a result tensor from device memory to the host and prints its first
// (up to) 10 elements.
//
// shape      - tensor shape; the element count is the product of its dimensions.
// deviceAddr - address of the pointer to the device buffer holding the tensor.
// name       - label printed in front of the values.
// elemSize   - bytes per element. 2 selects the BF16 path (each raw 16-bit
//              value is widened to float32 for display); sizeof(float) selects
//              the float32 path. Any other size is rejected, because reading
//              it into a float buffer would overflow the host allocation.
//              NOTE(review): FLOAT16 data is also 2 bytes wide but would be
//              decoded as BF16 here and therefore print wrong values.
void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr, const char *name, size_t elemSize = sizeof(float))
{
    const int64_t size = GetShapeSize(shape);
    const size_t copyBytes = static_cast<size_t>(size) * elemSize;
    const int64_t printCount = std::min(size, static_cast<int64_t>(10));

    if (elemSize == sizeof(uint16_t)) {
        // BF16: read raw bytes, convert to float for display
        std::vector<uint16_t> rawData(size, 0);
        auto ret = aclrtMemcpy(rawData.data(), copyBytes, *deviceAddr, copyBytes, ACL_MEMCPY_DEVICE_TO_HOST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
        LOG_PRINT("%s result (first 10 elements):\n", name);
        for (int64_t i = 0; i < printCount; i++) {
            // A BF16 value is the upper 16 bits of a float32; shift it into
            // place and reinterpret via memcpy (union punning is undefined
            // behaviour in C++).
            const uint32_t bits = static_cast<uint32_t>(rawData[i]) << 16;
            float value = 0.0f;
            std::memcpy(&value, &bits, sizeof(value));
            LOG_PRINT("  [%lld] = %f\n", static_cast<long long>(i), value);
        }
    } else if (elemSize == sizeof(float)) {
        // float32
        std::vector<float> resultData(size, 0);
        auto ret = aclrtMemcpy(resultData.data(), copyBytes, *deviceAddr, copyBytes, ACL_MEMCPY_DEVICE_TO_HOST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
        LOG_PRINT("%s result (first 10 elements):\n", name);
        for (int64_t i = 0; i < printCount; i++) {
            LOG_PRINT("  [%lld] = %f\n", static_cast<long long>(i), resultData[i]);
        }
    } else {
        LOG_PRINT("%s: unsupported element size %zu\n", name, elemSize);
    }
}

int Init(int32_t deviceId, aclrtContext *context, aclrtStream *stream)
{
    // 固定写法,AscendCL初始化
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateContext(context, deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetCurrentContext(*context);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
}

template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
                    aclDataType dataType, aclTensor **tensor)
{
    auto size = GetShapeSize(shape) * sizeof(T);
    // 调用aclrtMalloc申请device侧内存
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    // 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);

    // 计算连续tensor的strides
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    // 调用aclCreateTensor接口创建aclTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    return 0;
}

// Sample driver for aclnnMhcPreBackward: initializes the runtime, builds all
// input/output tensors with T=1024, N=4, D=512, runs the two-phase aclnn call,
// prints the first elements of every output and releases all resources.
// Returns 0 on success or the error code of the first failing call.
// NOTE: on an early error return the resources created so far are not
// released; this is tolerated only because the process exits immediately.
int main()
{
    // 1. (Boilerplate) device/context/stream initialization; see the AscendCL API list.
    // Set deviceId according to the device actually in use.
    int32_t deviceId = 0;
    aclrtContext context;
    aclrtStream stream;
    auto ret = Init(deviceId, &context, &stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

    // 2. Build the inputs and outputs as required by the API.
    std::vector<int64_t> xShape = {1024, 4, 512};      // T, N, D
    std::vector<int64_t> phiShape = {24, 2048};        // 2 * N + N * N, N * D
    std::vector<int64_t> alphaShape = {3};             // fixed size
    std::vector<int64_t> gradHInShape = {1024, 512};   // T, D
    std::vector<int64_t> gradHPostShape = {1024, 4};   // T, N
    std::vector<int64_t> gradHResShape = {1024, 4, 4}; // T, N, N
    std::vector<int64_t> invRmsShape = {1024};         // T
    std::vector<int64_t> hMixShape = {1024, 24};       // T, 2 * N + N * N
    std::vector<int64_t> hPreShape = {1024, 4};        // T, N
    std::vector<int64_t> hPostShape = {1024, 4};       // T, N
    std::vector<int64_t> gammaShape = {4, 512};        // N, D
    std::vector<int64_t> gradXShape = {1024, 4, 512};  // T, N, D
    std::vector<int64_t> gradPhiShape = {24, 2048};    // 2 * N + N * N, N * D
    std::vector<int64_t> gradAlphaShape = {3};         // fixed size
    std::vector<int64_t> gradGammaShape = {4, 512};    // N, D
    std::vector<int64_t> gradBiasShape = {24};         // 2*N+N*N
    std::vector<int64_t> gradXPostOptionalShape = {1024, 4, 512}; // T, N, D

    void *xDeviceAddr = nullptr;
    void *phiDeviceAddr = nullptr;
    void *alphaDeviceAddr = nullptr;
    void *gradHInDeviceAddr = nullptr;
    void *gradHPostDeviceAddr = nullptr;
    void *gradHResDeviceAddr = nullptr;
    void *invRmsDeviceAddr = nullptr;
    void *hMixDeviceAddr = nullptr;
    void *hPreDeviceAddr = nullptr;
    void *hPostDeviceAddr = nullptr;
    void *gammaDeviceAddr = nullptr;
    void *gradXDeviceAddr = nullptr;
    void *gradPhiDeviceAddr = nullptr;
    void *gradAlphaDeviceAddr = nullptr;
    void *gradBiasDeviceAddr = nullptr;
    void *gradGammaDeviceAddr = nullptr;
    void *gradXPostOptionalDeviceAddr = nullptr;

    aclTensor *x = nullptr;
    aclTensor *phi = nullptr;
    aclTensor *alpha = nullptr;
    aclTensor *gradHIn = nullptr;
    aclTensor *gradHPost = nullptr;
    aclTensor *gradHRes = nullptr;
    aclTensor *invRms = nullptr;
    aclTensor *hMix = nullptr;
    aclTensor *hPre = nullptr;
    aclTensor *hPost = nullptr;
    aclTensor *gamma = nullptr;
    aclTensor *gradX = nullptr;
    aclTensor *gradPhi = nullptr;
    aclTensor *gradAlpha = nullptr;
    aclTensor *gradBias = nullptr;
    aclTensor *gradGamma = nullptr;
    aclTensor *gradXPostOptional = nullptr;

    // BF16 tensors are uploaded as raw 16-bit patterns. The value 1.0 must be
    // written as its BF16 encoding 0x3F80; the integer 1 would be the bit
    // pattern 0x0001, i.e. a denormal close to zero rather than 1.0.
    const short kBf16One = 0x3F80;

    std::vector<short> xHostData(1024 * 4 * 512, kBf16One);
    std::vector<float> phiHostData(24 * 2048, 1.0);
    std::vector<float> alphaHostData(3, 1.0);
    std::vector<short> gradHInHostData(1024 * 512, kBf16One);
    std::vector<float> gradHPostHostData(1024 * 4, 1.0);
    std::vector<float> gradHResHostData(1024 * 4 * 4, 1.0);
    std::vector<float> invRmsHostData(1024, 1.0);
    std::vector<float> hMixHostData(1024 * 24, 1.0);
    std::vector<float> hPreHostData(1024 * 4, 1.0);
    std::vector<float> hPostHostData(1024 * 4, 1.0);
    std::vector<float> gammaHostData(4 * 512, 1.0);
    std::vector<short> gradXHostData(1024 * 4 * 512, 0);
    std::vector<float> gradPhiHostData(24 * 2048, 0);
    std::vector<float> gradAlphaHostData(3, 0);
    std::vector<float> gradBiasHostData(24, 0);
    std::vector<float> gradGammaHostData(4 * 512, 0);
    std::vector<short> gradXPostOptionalHostData(1024 * 4 * 512, 0); // all-zero bits are BF16 0.0

    ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(phiHostData, phiShape, &phiDeviceAddr, aclDataType::ACL_FLOAT, &phi);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(alphaHostData, alphaShape, &alphaDeviceAddr, aclDataType::ACL_FLOAT, &alpha);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradHInHostData, gradHInShape, &gradHInDeviceAddr, aclDataType::ACL_BF16, &gradHIn);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradHPostHostData, gradHPostShape, &gradHPostDeviceAddr, aclDataType::ACL_FLOAT, &gradHPost);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradHResHostData, gradHResShape, &gradHResDeviceAddr, aclDataType::ACL_FLOAT, &gradHRes);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(invRmsHostData, invRmsShape, &invRmsDeviceAddr, aclDataType::ACL_FLOAT, &invRms);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(hMixHostData, hMixShape, &hMixDeviceAddr, aclDataType::ACL_FLOAT, &hMix);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(hPreHostData, hPreShape, &hPreDeviceAddr, aclDataType::ACL_FLOAT, &hPre);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(hPostHostData, hPostShape, &hPostDeviceAddr, aclDataType::ACL_FLOAT, &hPost);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gammaHostData, gammaShape, &gammaDeviceAddr, aclDataType::ACL_FLOAT, &gamma);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradXHostData, gradXShape, &gradXDeviceAddr, aclDataType::ACL_BF16, &gradX);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradPhiHostData, gradPhiShape, &gradPhiDeviceAddr, aclDataType::ACL_FLOAT, &gradPhi);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradAlphaHostData, gradAlphaShape, &gradAlphaDeviceAddr, aclDataType::ACL_FLOAT, &gradAlpha);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradBiasHostData, gradBiasShape, &gradBiasDeviceAddr, aclDataType::ACL_FLOAT, &gradBias);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradGammaHostData, gradGammaShape, &gradGammaDeviceAddr, aclDataType::ACL_FLOAT, &gradGamma);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(gradXPostOptionalHostData, gradXPostOptionalShape, &gradXPostOptionalDeviceAddr,
                          aclDataType::ACL_BF16, &gradXPostOptional);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    float hc_eps = 1e-6f;

    // 3. Call the CANN operator library API.
    // Both values are outputs of the first-phase call, so start from a
    // well-defined empty state instead of an arbitrary size / dangling pointer.
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;

    // First-phase call: query the workspace size and obtain the executor.
    ret = aclnnMhcPreBackwardGetWorkspaceSize(x, phi, alpha, gradHIn, gradHPost, gradHRes, invRms, hMix, hPre, hPost,
                                              gamma, gradXPostOptional, hc_eps, gradX, gradPhi, gradAlpha, gradBias,
                                              gradGamma, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMhcPreBackwardGetWorkspaceSize failed. ERROR: %d\n", ret);
              return ret);

    // Allocate device memory for the workspace size reported by the first phase.
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
    }

    // Second-phase call: launch the computation on the stream.
    ret = aclnnMhcPreBackward(workspaceAddr, workspaceSize, executor, stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMhcPreBackward failed. ERROR: %d\n", ret); return ret);

    // 4. (Boilerplate) wait for the task to finish.
    ret = aclrtSynchronizeStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    // 5. Fetch the outputs: copy the results from device memory to the host and print them.
    PrintOutResult(gradXShape, &gradXDeviceAddr, "gradX", sizeof(short));
    PrintOutResult(gradPhiShape, &gradPhiDeviceAddr, "gradPhi");
    PrintOutResult(gradAlphaShape, &gradAlphaDeviceAddr, "gradAlpha");
    PrintOutResult(gradBiasShape, &gradBiasDeviceAddr, "gradBias");
    PrintOutResult(gradGammaShape, &gradGammaDeviceAddr, "gradGamma");

    // 6. Destroy the aclTensor objects.
    aclDestroyTensor(x);
    aclDestroyTensor(phi);
    aclDestroyTensor(alpha);
    aclDestroyTensor(gradHIn);
    aclDestroyTensor(gradHPost);
    aclDestroyTensor(gradHRes);
    aclDestroyTensor(invRms);
    aclDestroyTensor(hMix);
    aclDestroyTensor(hPre);
    aclDestroyTensor(hPost);
    aclDestroyTensor(gamma);
    aclDestroyTensor(gradX);
    aclDestroyTensor(gradPhi);
    aclDestroyTensor(gradAlpha);
    aclDestroyTensor(gradBias);
    aclDestroyTensor(gradGamma);
    aclDestroyTensor(gradXPostOptional);

    // 7. Release device resources.
    aclrtFree(xDeviceAddr);
    aclrtFree(phiDeviceAddr);
    aclrtFree(alphaDeviceAddr);
    aclrtFree(gradHInDeviceAddr);
    aclrtFree(gradHPostDeviceAddr);
    aclrtFree(gradHResDeviceAddr);
    aclrtFree(invRmsDeviceAddr);
    aclrtFree(hMixDeviceAddr);
    aclrtFree(hPreDeviceAddr);
    aclrtFree(hPostDeviceAddr);
    aclrtFree(gammaDeviceAddr);
    aclrtFree(gradXDeviceAddr);
    aclrtFree(gradPhiDeviceAddr);
    aclrtFree(gradAlphaDeviceAddr);
    aclrtFree(gradBiasDeviceAddr);
    aclrtFree(gradGammaDeviceAddr);
    aclrtFree(gradXPostOptionalDeviceAddr);
    if (workspaceSize > 0) {
        aclrtFree(workspaceAddr);
    }
    aclrtDestroyStream(stream);
    aclrtDestroyContext(context);
    aclrtResetDevice(deviceId);
    aclFinalize();

    return 0;
}