MhcPreSinkhornBackward

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能MhcPreSinkhornBackwardMhcPreSinkhorn 的反向算子,用于计算 mHC(Manifold-Constrained Hyper-Connections)结构中 Sinkhorn 变换的反向梯度传播。

  • 主要输出gradXgradPhigradAlphagradBias

  • 前向缓存依赖hPrehcBeforeNorminvRmssumOutnormOut

  • 计算公式

    gradX=∇x(MhcPreSinkhorn(x,ϕ,α,bias))gradPhi=∇ϕ(MhcPreSinkhorn(x,ϕ,α,bias))gradAlpha=∇α(MhcPreSinkhorn(x,ϕ,α,bias))gradBias=∇bias(MhcPreSinkhorn(x,ϕ,α,bias))\begin{aligned} gradX &= \nabla_{x}(\text{MhcPreSinkhorn}(x, \phi, \alpha, \text{bias})) \\ gradPhi &= \nabla_{\phi}(\text{MhcPreSinkhorn}(x, \phi, \alpha, \text{bias})) \\ gradAlpha &= \nabla_{\alpha}(\text{MhcPreSinkhorn}(x, \phi, \alpha, \text{bias})) \\ gradBias &= \nabla_{\text{bias}}(\text{MhcPreSinkhorn}(x, \phi, \alpha, \text{bias})) \end{aligned}

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
gradHin 输入 输出h_in的梯度。 BFLOAT16,FLOAT16 ND
gradHPost 输入 输出h_post的梯度。 FLOAT32 ND
gradHRes 输入 输出h_res的梯度。 FLOAT32 ND
x 输入 前向输入x。 BFLOAT16,FLOAT16 ND
phi 输入 前向参数phi。 FLOAT32 ND
alpha 输入 前向参数alpha。 FLOAT32 ND
bias 输入 前向参数bias。 FLOAT32 ND
hPre 输入 前向保存的中间结果h_pre。 FLOAT32 ND
hcBeforeNorm 输入 前向保存的中间结果hc_before_norm。 FLOAT32 ND
invRms 输入 前向保存的中间结果inv_rms。 FLOAT32 ND
sumOut 输入 Sinkhorn变换正向计算保存的中间sum结果。 FLOAT32 ND
normOut 输入 Sinkhorn变换正向计算保存的中间norm结果。 FLOAT32 ND
hcEps 属性 数值稳定性参数,默认值1e-6。 DOUBLE -
gradX 输出 输入x的梯度。 BFLOAT16,FLOAT16 ND
gradPhi 输出 参数phi的梯度。 FLOAT32 ND
gradAlpha 输出 参数alpha的梯度。 FLOAT32 ND
gradBias 输出 参数bias的梯度。 FLOAT32 ND

约束说明

sumOut的shape记为(2*sk_iter_count,B,S,N) x的shape记为(B,S,N,C)

  • sk_iter_count 当前仅支持 20
  • N 当前仅支持 4
  • C 大于0 小于 100000 且可以被128整除。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_mhc_pre.cpp 通过aclnnMhcPreSinkhornBackward 接口方式调用MhcPreSinkhornBackward算子。