aclnnMhcPreBackward
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
接口功能:MhcPreBackward是MhcPre的反向算子,MhcPre算子基于一系列计算得到mHC(Manifold-Constrained Hyper-Connections)架构中的HresH^{res}和HpostH^{post}投影矩阵以及Atten或MLP层的输入矩阵hinh^{in}
-
计算公式:
-
输出组合梯度计算
- 正向公式:
H_in=∑i=1Nx[B,S,i,:]⋅H_pren[B,S,i]H\_in = \sum_{i=1}^{N} x[{B,S,i,:}] · H\_pre_n[B,S,i]
- 反向计算:
H_pre_grad=Reduce(H_in_grad.unsqueeze(−2)⊙x,dim=−1)([B,S,N])x_grad_vec3=H_in_grad×H_pre([B,S,N,D])\begin{aligned} H\_pre\_grad &= \text{Reduce}\left(H\_in\_grad.\text{unsqueeze}(-2) \odot x, \text{dim}=-1\right) \quad ([B,S,N]) \\ x\_grad\_vec3 &= H\_in\_grad \times H\_pre \quad ([B,S,N,D]) \end{aligned}
- 正向公式:
-
Sigmoid门控反向(H_pre)
- 正向公式:
H_pre=Sigmoid(α_pre∗H_pre_1+bias_pre)+hc_epsH\_pre = \text{Sigmoid}(\alpha\_pre * H\_pre\_1 + bias\_pre) + hc\_eps
- 反向计算:
s=H_pre−hc_epsH_pre_2_grad=H_pre_grad⊙s⊙(1−s)H_pre_1_grad=H_pre_2_grad⋅α_preα_pre_grad=∑b,s,nB,S,N(H_pre_2_grad⋅H_pre_1)bias_pre_grad=∑b,sB,SH_pre_2_grad([N])\begin{aligned} s &= H\_pre - hc\_eps \\ H\_pre\_2\_grad &= H\_pre\_grad \odot s \odot (1 - s) \\ H\_pre\_1\_grad &= H\_pre\_2\_grad \cdot \alpha\_pre \\ \alpha\_pre\_grad &= \sum_{b,s,n}^{B,S,N} \left(H\_pre\_2\_grad \cdot H\_pre\_1\right) \\ bias\_pre\_grad &= \sum_{b,s}^{B,S} H\_pre\_2\_grad \quad ([N]) \end{aligned}
-
Sigmoid门控反向(H_post)
- 正向公式:
H_post=Sigmoid(α_post∗H_post_1+bias_post)∗2H\_post = \text{Sigmoid}(\alpha\_post * H\_post\_1 + bias\_post) * 2
- 反向计算:
H_post_2_grad=H_post_grad⊙(H_post⋅(1−H_post2))H_post_1_grad=H_post_2_grad⋅αpostαpost_grad=∑b,s,nB,S,N(H_post_2_grad⋅Hpost_1)bias_post_grad=∑b,sB,SH_post_2_grad([N])\begin{aligned} H\_post\_2\_grad &= H\_post\_grad \odot \left(H\_post \cdot \left(1 - \frac{H\_post}{2}\right)\right) \\ H\_post\_1\_grad &= H\_post\_2\_grad \cdot \alpha_{post} \\ \alpha_{post\_grad} &= \sum_{b,s,n}^{B,S,N} \left(H\_post\_2\_grad \cdot H_{post\_1}\right) \\ bias\_post\_grad &= \sum_{b,s}^{B,S} H\_post\_2\_grad \quad ([N]) \end{aligned}
-
残差连接反向(H_res)
- 正向公式:
H_res=α_res∗H_res_1+bias_resH\_res = \alpha\_res * H\_res\_1 + bias\_res
- 反向计算:
H_res_2_grad=H_res_grad⋅αres([B,S,N,N])α_res_grad=∑b,s,i,jB,S,N,N(H_res_grad⋅H_res_2)bias_res_grad=∑b,sB,SH_res_grad([N,N])H_res_1_grad=Reshape(H_res_2_grad)([B,S,N2])\begin{aligned} H\_res\_2\_grad &= H\_res\_grad \cdot \alpha_{res} \quad ([B,S,N,N]) \\ \alpha\_res\_grad &= \sum_{b,s,i,j}^{B,S,N,N} \left(H\_res\_grad \cdot H\_res\_2\right) \\ bias\_res\_grad &= \sum_{b,s}^{B,S} H\_res\_grad \quad ([N,N]) \\ H\_res\_1\_grad &= \text{Reshape}(H\_res\_2\_grad) \quad ([B,S,N^2]) \end{aligned}
-
RMSNorm Fusion反向
- 正向公式:
H_mix_tmp=H_mix∗inv_rmsH\_mix\_tmp = H\_mix * inv\_rms
- 反向计算:
H_mix_tmp_grad=Concat(H_pre_1_grad,H_post_1_grad,H_res_1_grad)([B,S,2N+N2])H_mix_grad=H_mix_tmp_grad⋅inv_rmsinv_rmsgrad=∑last_dim(H_mix_tmp_grad⋅H_mix)([B,S,1])\begin{aligned} H\_mix\_tmp\_grad &= \text{Concat}(H\_pre\_1\_grad, H\_post\_1\_grad, H\_res\_1\_grad) \quad ([B,S,2N+N^2]) \\ H\_mix\_grad &= H\_mix\_tmp\_grad \cdot inv\_rms \\ inv\_rms_{grad} &= \sum_{\text{last\_dim}} \left(H\_mix\_tmp\_grad \cdot H\_mix\right) \quad ([B,S,1]) \end{aligned}
-
矩阵乘法反向
- 正向公式:
H_mix=x_rs@phiTH\_mix = x\_rs @ phi^T
x_rs=x∗gammax\_rs = x * gamma
- 反向计算:
x_rs_grad=H_mix_grad@phi([B,S,ND])X=Reshape(x_rs,[B⋅S,ND])G=Reshape(H_mix_grad,[B⋅S,2N+N2])phigrad=GT@X([2N+N2,ND])\begin{aligned} x\_rs\_grad &= H\_mix\_grad @ phi \quad ([B,S,ND]) \\ X &= \text{Reshape}(x\_rs, [B\cdot S, ND]) \\ G &= \text{Reshape}(H\_mix\_grad, [B\cdot S, 2N+N^2]) \\ phi_{grad} &= G^T @ X \quad ([2N+N^2, ND]) \end{aligned}
-
特征缩放反向
- 正向公式:
x_rs=x∗gammax\_rs = x * gamma
- 反向计算:
x_grad_mm=x_rs_grad∗gammagamma_grad=∑b=1B∑s=1S(x∗x_rs_grad)([N,D])\begin{aligned} x\_grad\_mm &= x\_rs\_grad * gamma \\ gamma\_grad &= \sum_{b=1}^{B}\sum_{s=1}^{S} (x * x\_rs\_grad)\quad ([N,D]) \end{aligned}
-
RMS归一化梯度计算
- 正向公式:
inv_rms=11n∑i=1nxi2+eps,其中 n=N∗Dinv\_rms = \frac{1}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2 + eps}}, \quad 其中\ n = N * D
- 反向计算:
x_rs_grad_inv=−(inv_rms_grad⋅inv_rms3N∗D)⋅x_rsx_rs_grad=x_grad_mm+x_rs_grad_invx_grad_vec1=Reshape(x_rs_grad,[B,S,N,D])x_grad=x_grad_vec3+x_grad_vec1\begin{aligned} x\_rs\_grad\_inv &= - \left(\frac{inv\_rms\_grad \cdot {inv\_rms}^3}{N*D}\right) \cdot x\_rs \\ x\_rs\_grad &= x\_grad\_mm + x\_rs\_grad\_inv \\ x\_grad\_vec1 &= \text{Reshape}(x\_rs\_grad, [B,S,N,D]) \\ x\_grad &= x\_grad\_vec3 + x\_grad\_vec1 \end{aligned}
- 正向公式:
-
融合mhc_post的grad_x相加操作
x_grad=x_grad+grad_x_post\begin{aligned} x\_grad &= x\_grad + grad\_x\_post \end{aligned}
-
函数原型
每个算子分为两段式接口,必须先调用“aclnnMhcPreBackwardGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnMhcPreBackward”接口执行计算。
aclnnStatus aclnnMhcPreBackwardGetWorkspaceSize(
const aclTensor *x,
const aclTensor *phi,
const aclTensor *alpha,
const aclTensor *gradHIn,
const aclTensor *gradHPost,
const aclTensor *gradHRes,
const aclTensor *invRms,
const aclTensor *hMix,
const aclTensor *hPre,
const aclTensor *hPost,
const aclTensor *gamma,
const aclTensor *gradXPostOptional,
float hcEps,
const aclTensor *gradX,
const aclTensor *gradPhi,
const aclTensor *gradAlpha,
const aclTensor *gradBias,
const aclTensor *gradGamma,
uint64_t *workspaceSize,
aclOpExecutor **executor)
aclnnStatus aclnnMhcPreBackward(
void *workspace,
uint64_t workspaceSize,
aclOpExecutor *executor,
aclrtStream stream)
aclnnMhcPreBackwardGetWorkspaceSize
-
参数说明:
参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor x 输入 待计算数据,表示网络中mHC层的输入数据。 - 不支持空Tensor。
BFLOAT16、FLOAT16 ND (B,S,N,D)、(T,N,D)
B:支持泛化;S:支持泛化;T:B*S。√ phi 输入 mHC的参数矩阵。 - 不支持空Tensor。
FLOAT32 ND (2N+N*N,N*D)
N:与x的N保持一致;D:与x的D保持一致√ alpha 输入 mHC的缩放参数alpha。 - 不支持空Tensor。
FLOAT32 - (3) - gradHIn 输入 hIn作为Atten/MLP层的输入。正向输出hIn对应的梯度。 - 不支持空Tensor。
BFLOAT16、FLOAT16 ND (B,S,D)、(T,D)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;D:与x的D保持一致。√ gradHPost 输入 正向输出hPost对应的梯度。 - 不支持空Tensor。
FLOAT32 ND (B,S,N)、(T,N)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。√ gradHRes 输入 正向输出hRes对应的梯度。 - 不支持空Tensor。
FLOAT32 ND (B,S,N,N)、(T,N,N)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。√ invRms 输入 正向RmsNorm计算的invRms。 - 不支持空Tensor。
FLOAT32 ND (B,S)、(T)
B:与x的B保持一致;S:与x的S保持一致;T:B*S。√ hMix 输入 正向计算流x@phi的结果 - 不支持空Tensor。
FLOAT32 ND (B,S,2N+N*N)、(T,2N+N*N)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。√ hPre 输入 正向sigmoid计算之后的hPre矩阵 - 不支持空Tensor。
FLOAT32 ND (B,S,N)、(T,N)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。√ hPost 输入 正向的hPost输出 - 不支持空Tensor。
FLOAT32 ND (B,S,N)、(T,N)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致。√ gamma 可选输入 RmsNorm的缩放系数gamma - 不支持空Tensor。
- 如果传入nullptr,则表示全1的tensor。
FLOAT32 ND (N,D)
N:与x的N保持一致;D:与x的D保持一致。√ gradXPostOptional 可选输入 post反向输出的gradX - 不支持空Tensor。
- 如果传入nullptr,则表示全0的tensor。
BFLOAT16、FLOAT16 ND (B,S,N,D)、(T,N,D)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致;D:与x的D保持一致。√ hcEps 可选输入 HPre的sigmoid后的eps参数 - 建议值为1e-6
FLOAT32 - - - gradX 输出 x对应的梯度。 - 与输入x的维度、数据类型保持一致
BFLOAT16、FLOAT16 ND (B,S,N,D)、(T,N,D)
B:与x的B保持一致;S:与x的S保持一致;T:B*S;N:与x的N保持一致;D:与x的D保持一致。√ gradPhi 输出 phi对应的梯度。 - 与输入phi的维度保持一致
FLOAT32 ND (2N+N*N,N*D)
N:与x的N保持一致;D:与x的D保持一致。√ gradAlpha 输出 alpha对应的梯度。 - 与输入alpha维度保持一致。
FLOAT32 ND (3) - gradBias 输出 bias对应的梯度。 - FLOAT32 ND (2N+N*N)
N:与x的N保持一致。- gradGamma 可选输出 gamma对应的梯度。 - 当输入gamma不为nullptr时,此变量才会输出。
- 与输入gamma的Shape维度保持一致。
FLOAT32 ND (N,D)
N:与x的N保持一致;D:与x的D保持一致。√ workspaceSize 输出 返回需要在Device侧申请的workspace大小。 - - - - - executor 输出 返回op执行器,包含了算子计算流程。 - - - - - -
返回值
返回aclnnStatus状态码,具体参见aclnn返回码。
第一段接口完成入参校验,出现以下场景时报错:
返回值 错误码 描述 ACLNN_ERR_PARAM_NULLPTR 161001 必选参数或者输出是空指针。 ACLNN_ERR_PARAM_INVALID 161002 输入变量,x、phi、gamma、alpha的数据类型和数据格式不在支持的范围内。 ACLNN_ERR_RUNTIME_ERROR 361001 API内部调用npu runtime的接口异常。
aclnnMhcPreBackward
-
参数说明:
参数名 输入/输出 描述 workspace 输入 在Device侧申请的workspace内存地址。 workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnMhcPreBackwardGetWorkspaceSize获取。 executor 输入 op执行器,包含了算子计算流程。 stream 输入 指定执行任务的Stream流。 -
返回值
返回aclnnStatus状态码,具体参见aclnn返回码。
约束说明
-
确定性计算:
- aclnnMhcPreBackward默认采用确定性实现。
-
规格约束
规格项 规格 规格说明 n 4、 6、8 n值目前支持4、 6、 8 D 1~16384 D支持1~16384范围以内且64元素对齐
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
#include <iostream>
#include <vector>
#include <numeric>
#include "acl/acl.h"
#include "aclnnop/aclnn_mhc_pre_backward.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
void PrintOutResult(std::vector<int64_t> &shape, void **deviceAddr, const char *name, size_t elemSize = sizeof(float))
{
auto size = GetShapeSize(shape);
size_t copyBytes = size * elemSize;
if (elemSize == 2) {
// BF16: read raw bytes, convert to float for display
std::vector<uint16_t> rawData(size, 0);
auto ret = aclrtMemcpy(rawData.data(), copyBytes, *deviceAddr, copyBytes, ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
LOG_PRINT("%s result (first 10 elements):\n", name);
for (int64_t i = 0; i < std::min(size, (int64_t)10); i++) {
union {
uint32_t i;
float f;
} u;
u.i = (uint32_t)rawData[i] << 16;
LOG_PRINT(" [%ld] = %f\n", i, u.f);
}
} else {
// float32
std::vector<float> resultData(size, 0);
auto ret = aclrtMemcpy(resultData.data(), copyBytes, *deviceAddr, copyBytes, ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
LOG_PRINT("%s result (first 10 elements):\n", name);
for (int64_t i = 0; i < std::min(size, (int64_t)10); i++) {
LOG_PRINT(" [%ld] = %f\n", i, resultData[i]);
}
}
}
int Init(int32_t deviceId, aclrtContext *context, aclrtStream *stream)
{
// 固定写法,AscendCL初始化
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
ret = aclrtCreateContext(context, deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret); return ret);
ret = aclrtSetCurrentContext(*context);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret); return ret);
ret = aclrtCreateStream(stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
return 0;
}
template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
aclDataType dataType, aclTensor **tensor)
{
auto size = GetShapeSize(shape) * sizeof(T);
// 调用aclrtMalloc申请device侧内存
auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
// 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
// 计算连续tensor的strides
std::vector<int64_t> strides(shape.size(), 1);
for (int64_t i = shape.size() - 2; i >= 0; i--) {
strides[i] = shape[i + 1] * strides[i + 1];
}
// 调用aclCreateTensor接口创建aclTensor
*tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
shape.data(), shape.size(), *deviceAddr);
return 0;
}
int main()
{
// 1. (固定写法)device/context/stream初始化,参考AscendCL对外接口列表
// 根据自己的实际device填写deviceId
int32_t deviceId = 0;
aclrtContext context;
aclrtStream stream;
auto ret = Init(deviceId, &context, &stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);
// 2. 构造输入与输出,需要根据API的接口自定义构造
std::vector<int64_t> xShape = {1024, 4, 512}; // T, N, D
std::vector<int64_t> phiShape = {24, 2048}; // 2 * N + N * N, N * D
std::vector<int64_t> alphaShape = {3}; // 固定大小
std::vector<int64_t> gradHInShape = {1024, 512}; // T, D
std::vector<int64_t> gradHPostShape = {1024, 4}; // T, N
std::vector<int64_t> gradHResShape = {1024, 4, 4}; // T, N, N
std::vector<int64_t> invRmsShape = {1024}; // T
std::vector<int64_t> hMixShape = {1024, 24}; // T, 2 * N + N * N
std::vector<int64_t> hPreShape = {1024, 4}; // T, N
std::vector<int64_t> hPostShape = {1024, 4}; // T, N
std::vector<int64_t> gammaShape = {4, 512}; // N, D
std::vector<int64_t> gradXShape = {1024, 4, 512}; // T, N, D
std::vector<int64_t> gradPhiShape = {24, 2048}; // 2 * N + N * N, N * D
std::vector<int64_t> gradAlphaShape = {3}; // 固定大小
std::vector<int64_t> gradGammaShape = {4, 512}; // N, D
std::vector<int64_t> gradBiasShape = {24}; // 2*N+N*N
std::vector<int64_t> gradXPostOptionalShape = {1024, 4, 512};
void *xDeviceAddr = nullptr;
void *phiDeviceAddr = nullptr;
void *alphaDeviceAddr = nullptr;
void *gradHInDeviceAddr = nullptr;
void *gradHPostDeviceAddr = nullptr;
void *gradHResDeviceAddr = nullptr;
void *invRmsDeviceAddr = nullptr;
void *hMixDeviceAddr = nullptr;
void *hPreDeviceAddr = nullptr;
void *hPostDeviceAddr = nullptr;
void *gammaDeviceAddr = nullptr;
void *gradXDeviceAddr = nullptr;
void *gradPhiDeviceAddr = nullptr;
void *gradAlphaDeviceAddr = nullptr;
void *gradBiasDeviceAddr = nullptr;
void *gradGammaDeviceAddr = nullptr;
void *gradXPostOptionalDeviceAddr = nullptr;
aclTensor *x = nullptr;
aclTensor *phi = nullptr;
aclTensor *alpha = nullptr;
aclTensor *gradHIn = nullptr;
aclTensor *gradHPost = nullptr;
aclTensor *gradHRes = nullptr;
aclTensor *invRms = nullptr;
aclTensor *hMix = nullptr;
aclTensor *hPre = nullptr;
aclTensor *hPost = nullptr;
aclTensor *gamma = nullptr;
aclTensor *gradX = nullptr;
aclTensor *gradPhi = nullptr;
aclTensor *gradAlpha = nullptr;
aclTensor *gradBias = nullptr;
aclTensor *gradGamma = nullptr;
aclTensor *gradXPostOptional = nullptr;
std::vector<short> xHostData(1024 * 4 * 512, 1.0);
std::vector<float> phiHostData(24 * 2048, 1.0);
std::vector<float> alphaHostData(3, 1.0);
std::vector<short> gradHInHostData(1024 * 512, 1.0);
std::vector<float> gradHPostHostData(1024 * 4, 1.0);
std::vector<float> gradHResHostData(1024 * 4 * 4, 1.0);
std::vector<float> invRmsHostData(1024, 1.0);
std::vector<float> hMixHostData(1024 * 24, 1.0);
std::vector<float> hPreHostData(1024 * 4, 1.0);
std::vector<float> hPostHostData(1024 * 4, 1.0);
std::vector<float> gammaHostData(4 * 512, 1.0);
std::vector<short> gradXHostData(1024 * 4 * 512, 0);
std::vector<float> gradPhiHostData(24 * 2048, 0);
std::vector<float> gradAlphaHostData(3, 0);
std::vector<float> gradBiasHostData(24, 0);
std::vector<float> gradGammaHostData(4 * 512, 0);
std::vector<short> gradXPostOptionalHostData(1024 * 4 * 512, 0);
ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_BF16, &x);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(phiHostData, phiShape, &phiDeviceAddr, aclDataType::ACL_FLOAT, &phi);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(alphaHostData, alphaShape, &alphaDeviceAddr, aclDataType::ACL_FLOAT, &alpha);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradHInHostData, gradHInShape, &gradHInDeviceAddr, aclDataType::ACL_BF16, &gradHIn);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradHPostHostData, gradHPostShape, &gradHPostDeviceAddr, aclDataType::ACL_FLOAT, &gradHPost);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradHResHostData, gradHResShape, &gradHResDeviceAddr, aclDataType::ACL_FLOAT, &gradHRes);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(invRmsHostData, invRmsShape, &invRmsDeviceAddr, aclDataType::ACL_FLOAT, &invRms);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(hMixHostData, hMixShape, &hMixDeviceAddr, aclDataType::ACL_FLOAT, &hMix);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(hPreHostData, hPreShape, &hPreDeviceAddr, aclDataType::ACL_FLOAT, &hPre);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(hPostHostData, hPostShape, &hPostDeviceAddr, aclDataType::ACL_FLOAT, &hPost);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gammaHostData, gammaShape, &gammaDeviceAddr, aclDataType::ACL_FLOAT, &gamma);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradXHostData, gradXShape, &gradXDeviceAddr, aclDataType::ACL_BF16, &gradX);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradPhiHostData, gradPhiShape, &gradPhiDeviceAddr, aclDataType::ACL_FLOAT, &gradPhi);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradAlphaHostData, gradAlphaShape, &gradAlphaDeviceAddr, aclDataType::ACL_FLOAT, &gradAlpha);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradBiasHostData, gradBiasShape, &gradBiasDeviceAddr, aclDataType::ACL_FLOAT, &gradBias);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradGammaHostData, gradGammaShape, &gradGammaDeviceAddr, aclDataType::ACL_FLOAT, &gradGamma);
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = CreateAclTensor(gradXPostOptionalHostData, gradXPostOptionalShape, &gradXPostOptionalDeviceAddr,
aclDataType::ACL_BF16, &gradXPostOptional);
CHECK_RET(ret == ACL_SUCCESS, return ret);
float hc_eps = 1e-6;
// 3. 调用CANN算子库API,需要修改为具体的Api名称
uint64_t workspaceSize = 80 * 1024 * 1024;
aclOpExecutor *executor;
// 调用aclnnMhcPreBackward第一段接口
ret = aclnnMhcPreBackwardGetWorkspaceSize(x, phi, alpha, gradHIn, gradHPost, gradHRes, invRms, hMix, hPre, hPost,
gamma, gradXPostOptional, hc_eps, gradX, gradPhi, gradAlpha, gradBias,
gradGamma, &workspaceSize, &executor);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMhcPreBackwardGetWorkspaceSize failed. ERROR: %d\n", ret);
return ret);
// 根据第一段接口计算出的workspaceSize申请device内存
void *workspaceAddr = nullptr;
if (workspaceSize > 0) {
ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
}
// 调用aclnnMhcPreBackward第二段接口
ret = aclnnMhcPreBackward(workspaceAddr, workspaceSize, executor, stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnMhcPreBackward failed. ERROR: %d\n", ret); return ret);
// 4. (固定写法)同步等待任务执行结束
ret = aclrtSynchronizeStream(stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);
// 5. 获取输出的值,将device侧内存上的结果拷贝至host侧,需要根据具体API的接口定义修改
PrintOutResult(gradXShape, &gradXDeviceAddr, "gradX", sizeof(short));
PrintOutResult(gradPhiShape, &gradPhiDeviceAddr, "gradPhi");
PrintOutResult(gradAlphaShape, &gradAlphaDeviceAddr, "gradAlpha");
PrintOutResult(gradBiasShape, &gradBiasDeviceAddr, "gradBias");
PrintOutResult(gradGammaShape, &gradGammaDeviceAddr, "gradGamma");
// 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
aclDestroyTensor(x);
aclDestroyTensor(phi);
aclDestroyTensor(alpha);
aclDestroyTensor(gradHIn);
aclDestroyTensor(gradHPost);
aclDestroyTensor(gradHRes);
aclDestroyTensor(invRms);
aclDestroyTensor(hMix);
aclDestroyTensor(hPre);
aclDestroyTensor(hPost);
aclDestroyTensor(gamma);
aclDestroyTensor(gradX);
aclDestroyTensor(gradPhi);
aclDestroyTensor(gradAlpha);
aclDestroyTensor(gradBias);
aclDestroyTensor(gradGamma);
aclDestroyTensor(gradXPostOptional);
// 7. 释放device资源
aclrtFree(xDeviceAddr);
aclrtFree(phiDeviceAddr);
aclrtFree(alphaDeviceAddr);
aclrtFree(gradHInDeviceAddr);
aclrtFree(gradHPostDeviceAddr);
aclrtFree(gradHResDeviceAddr);
aclrtFree(invRmsDeviceAddr);
aclrtFree(hMixDeviceAddr);
aclrtFree(hPreDeviceAddr);
aclrtFree(hPostDeviceAddr);
aclrtFree(gammaDeviceAddr);
aclrtFree(gradXDeviceAddr);
aclrtFree(gradPhiDeviceAddr);
aclrtFree(gradAlphaDeviceAddr);
aclrtFree(gradBiasDeviceAddr);
aclrtFree(gradGammaDeviceAddr);
aclrtFree(gradXPostOptionalDeviceAddr);
if (workspaceSize > 0) {
aclrtFree(workspaceAddr);
}
aclrtDestroyStream(stream);
aclrtDestroyContext(context);
aclrtResetDevice(deviceId);
aclFinalize();
return 0;
}