MhcPreBackward

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能MhcPreBackwardMhcPre 的反向算子,用于计算 mHC(Manifold-Constrained Hyper-Connections)结构中的反向梯度。

  • 主要输出gradXgradPhigradAlphagradBias,以及在 gamma != nullptr 时输出 gradGamma

  • 前向缓存依赖invRmshMixhPrehPost

  • 可选输入gammagradXPostOptional

  • 计算公式

    gradX=∇x(MhcPre(x,ϕ,α,γ))gradPhi=∇ϕ(MhcPre(x,ϕ,α,γ))gradAlpha=∇α(MhcPre(x,ϕ,α,γ))gradBias=∇bias(MhcPre(x,ϕ,α,γ))\begin{aligned} gradX &= \nabla_{x}(\text{MhcPre}(x, \phi, \alpha, \gamma)) \\ gradPhi &= \nabla_{\phi}(\text{MhcPre}(x, \phi, \alpha, \gamma)) \\ gradAlpha &= \nabla_{\alpha}(\text{MhcPre}(x, \phi, \alpha, \gamma)) \\ gradBias &= \nabla_{bias}(\text{MhcPre}(x, \phi, \alpha, \gamma)) \end{aligned}

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 mHC层输入数据。 BFLOAT16、FLOAT16 ND
phi 输入 mHC参数矩阵。 FLOAT32 ND
alpha 输入 mHC缩放参数。 FLOAT32 ND
grad_h_in 输入 对h_in的梯度。 BFLOAT16、FLOAT16 ND
grad_h_post 输入 对h_post的梯度。 FLOAT32 ND
grad_h_res 输入 对h_res的梯度。 FLOAT32 ND
inv_rms 输入 前向缓存的inv_rms。 FLOAT32 ND
h_mix 输入 前向缓存的h_mix。 FLOAT32 ND
h_pre 输入 前向缓存的h_pre。 FLOAT32 ND
h_post 输入 前向缓存的h_post。 FLOAT32 ND
gamma 可选输入 RMSNorm缩放因子。 FLOAT32 ND
grad_x_post 可选输入 来自后续路径的grad_x累加项。 BFLOAT16、FLOAT16 ND
hc_eps 属性 h_pre sigmoid后使用的eps参数。 FLOAT32 -
grad_x 输出 x的梯度。 BFLOAT16、FLOAT16 ND
grad_phi 输出 phi的梯度。 FLOAT32 ND
grad_alpha 输出 alpha的梯度。 FLOAT32 ND
grad_bias 输出 bias整体梯度。 FLOAT32 ND
grad_gamma 可选输出 gamma的梯度。 FLOAT32 ND

约束说明

  • N 当前仅支持 468 三种取值。
  • D 支持 1~16384,且需满足 64 元素对齐。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_mhc_pre_backward.cpp 通过aclnnMhcPreBackward 接口方式调用MhcPreBackward算子。