MhcSinkhornBackward

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:MhcSinkhornBackward是MhcSinkhorn的反向算子。mHC(Manifold-Constrained Hyper-Connections)架构中的MhcSinkhorn算子对输入矩阵做sinkhorn变换得到双随机矩阵Hres\mathbf{H}_{\text{res}},输出的双随机矩阵的所有元素≥0、每一行之和为1且每一列之和为1 (具有范数保持、组合封闭性和凸组合几何解释三大特性)。对mHC架构中双随机矩阵Hres\mathbf{H}_{\text{res}}矩阵的梯度进行sinkhorn变换的反向计算得到输入Hres′\mathbf{H}'_{\text{res}}的梯度。

  • 计算公式:

    • Sinkhorn-Knopp算法在正向计算中通过 num_iters\mathbf{num\_iters} 次迭代归一化实现双随机投影,在反向传播的迭代计算中:

    • num_iters−1\mathbf{num\_iters}-1 次迭代:

      dot_prod2i+1=∑dim⁡=−2,keepdim=True(gradcurr ⋅ norm_out2i+1),gradcurr←gradcurr−dot_prod2i+1sum_out2i+1,dot_prod2i=∑dim⁡=−1,keepdim=True(gradcurr ⋅ norm_out2i),gradcurr←gradcurr−dot_prod2isum_out2i,\begin{aligned} \mathbf{dot\_prod}_{2i+1} &= \sum_{\dim=-2,\text{keepdim}=\text{True}} (\mathbf{grad}_{curr}\ {⋅}\ \mathbf{norm\_out}_{2i+1}), \\ \mathbf{grad}_{curr} &← \frac{\mathbf{grad}_{curr} - \mathbf{dot\_prod}_{2i+1}}{\mathbf{sum\_out}_{2i+1}}, \\ \mathbf{dot\_prod}_{2i} &= \sum_{\dim=-1,\text{keepdim}=\text{True}} (\mathbf{grad}_{curr}\ {⋅}\ \mathbf{norm\_out}_{2i}), \\ \mathbf{grad}_{curr} &← \frac{\mathbf{grad}_{curr} - \mathbf{dot\_prod}_{2i}}{\mathbf{sum\_out}_{2i}}, \\ \end{aligned}

    • 最后一次迭代:

      dot_prod1=∑dim⁡=−2,keepdim=True(gradcurr ⋅ norm_out1),gradcurr←gradcurr−dot_prod1sum_out1,dot_prod0=∑dim⁡=−1,keepdim=True(gradcurr ⋅ norm_out0),gradinput←(gradcurr−dot_prod0) ⋅ norm_out0\begin{aligned} \mathbf{dot\_prod}_{1} &= \sum_{\dim=-2,\text{keepdim}=\text{True}} (\mathbf{grad}_{curr}\ {⋅}\ \mathbf{norm\_out}_{1}), \\ \mathbf{grad}_{curr} &← \frac{\mathbf{grad}_{curr} - \mathbf{dot\_prod}_{1}}{\mathbf{sum\_out}_{1}}, \\ \mathbf{dot\_prod}_{0} &= \sum_{\dim=-1,\text{keepdim}=\text{True}} (\mathbf{grad}_{curr}\ {⋅}\ \mathbf{norm\_out}_{0}), \\ \mathbf{grad}_{input} &← ({\mathbf{grad}_{curr} - \mathbf{dot\_prod}_{0}})\ {⋅}\ \mathbf{norm\_out}_{0} \\ \end{aligned}

    • 其中:gradcurr\mathbf{grad}_\text{curr} 为初始梯度,gradinput\mathbf{grad}_\text{input} 为输出梯度,norm_outk\mathbf{norm\_out}_\text{k}为第kk次归一化方向向量,sum_outk\mathbf{sum\_out}_\text{k} 为对应的缩放系数。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
grad_y 输入 Sinkhorn变换输出的H_res的梯度。 FLOAT32 ND
norm 输入 Sinkhorn变换正向计算保存的中间norm结果。 FLOAT32 ND
sum 输入 Sinkhorn变换正向计算保存的中间sum结果。 FLOAT32 ND
grad_input 输出 Sinkhorn变换的输入的H_res的梯度。 FLOAT32 ND

约束说明

  • 输入 grad_y 仅支持 3 维 (T,n,n) 或 4 维 (B,S,n,n)。
  • 输入 norm 仅支持 1 维 (2*num_iters*n*align_n*B*S) 或 (2*num_iters*n*align_n*T) 。
  • 输入 sum 仅支持 1 维 (2*num_iters*align_n*B*S) 或 (2*num_iters*align_n*T) 。
  • num_iters:取值范围 1~100,超出则报参数无效。
  • n:仅支持 4、6或8。
  • align_n:固定取值为 8。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_mhc_sinkhorn_backward 通过aclnnMhcSinkhornBackward接口方式调用MhcSinkhornBackward算子。