MhcPre

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT x
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:基于一系列计算得到MHC架构中hidden层的Hres′\mathbf{H}'_{\text{res}}Hpost\mathbf{H}_{\text{post}}投影矩阵以及Attention或MLP层的输入矩阵hin\mathbf{h}_{\text{in}}。对Hres′\mathbf{H}'_{\text{res}}矩阵执行Sinkhorn迭代归一化变换,最终得到双随机矩阵Hres\mathbf{H}_{\text{res}};支持输出中间计算结果,用于反向梯度计算。包括sigmoid计算之后的Hlpre\mathbf{H^{pre}_l}矩阵、xl′⃗\vec{x^{'}_{l}}φ\mathbf{\varphi}矩阵乘的结果,输入x的RmsNorm结果xl′⃗\mathbf{\vec{x^{'}_{l}}}、迭代过程中的中间归一化结果和normOut\mathbf{normOut}和求和结果sumOut\mathbf{sumOut}

  • 计算公式

    xl′⃗=11d∑dim⁡=−2,keepdim=Truexi2+ϵHlpre=αlpre⋅(xl′⃗φlpre)+blpreHlpost=αlpost⋅(xl′⃗φlpost)+blpostHlres=αlres⋅(xl′⃗φlres)+blresHlpre=σ(Hlpre)Hlpost=2σ(Hlpost)hin=xl⃗Hlpre\begin{aligned} \vec{x^{'}_{l}} &= \frac{1}{\sqrt{\frac{1}{d} \sum_{\dim=-2,\text{keepdim}=\text{True}} x_i^2 + \epsilon}}\\ H^{pre}_l &= \alpha^{pre}_{l} ·(\vec{x^{'}_{l}}\varphi^{pre}_{l}) + b^{pre}_{l}\\ H^{post}_l &= \alpha^{post}_{l} ·(\vec{x^{'}_{l}}\varphi^{post}_{l}) + b^{post}_{l}\\ H^{res}_l &= \alpha^{res}_{l} ·(\vec{x^{'}_{l}}\varphi^{res}_{l}) + b^{res}_{l}\\ H^{pre}_l &= \sigma (H^{pre}_{l})\\ H^{post}_l &= 2\sigma (H^{post}_{l})\\ h_{in} &=\vec{x_{l}}H^{pre}_l \end{aligned}

    • Hlres\mathbf{H^{res}_l}作为输入,Sinkhorn变换共执行numIters\mathbf{numIters}次迭代,迭代过程中生成中间归一化结果normOut[k]\mathbf{normOut}[k]和求和结果sumOut[k]\mathbf{sumOut}[k],最终输出最后一次迭代的normOut\mathbf{normOut}作为变换结果。

      第一次迭代(初始化):

      normOut[0]=softmax(Hlres,dim⁡=−1)+ϵ,sumOut[1]=∑dim⁡=−2,keepdim=TruenormOut[0]+ϵ,normOut[1]=normOut[0]sum_out[1],\begin{aligned} \mathbf{normOut}[0] &= \text{softmax}(\mathbf{H^{res}_l}, \dim=-1) + \epsilon, \\ \mathbf{sumOut}[1] &= \sum_{\dim=-2,\text{keepdim}=\text{True}} \mathbf{normOut}[0] + \epsilon, \\ \mathbf{normOut}[1] &= \frac{\mathbf{normOut}[0]}{\mathbf{sum\_out}[1]}, \\ \end{aligned}

      ii次迭代(i=1,2,…,(num_iters−1)i = 1, 2, \dots, \mathbf({num\_iters}-1)):

      sumOut[2i]=∑dim⁡=−1,keepdim=TruenormOut[2i−1]+ϵ,normOut[2i]=normOut[2i−1]sum_out[2i],sumOut[2i+1]=∑dim⁡=−2,keepdim=TruenormOut[2i]+ϵ,normOut[2i+1]=normOut[2i]sum_out[2i+1],\begin{aligned} \mathbf{sumOut}[2i] &= \sum_{\dim=-1,\text{keepdim}=\text{True}} \mathbf{normOut}[2i-1] + \epsilon, \\ \mathbf{normOut}[2i] &= \frac{\mathbf{normOut}[2i-1]}{\mathbf{sum\_out}[2i]}, \\ \mathbf{sumOut}[2i+1] &= \sum_{\dim=-2,\text{keepdim}=\text{True}} \mathbf{normOut}[2i] + \epsilon, \\ \mathbf{normOut}[2i+1] &= \frac{\mathbf{normOut}[2i]}{\mathbf{sum\_out}[2i+1]}, \\ \end{aligned}

    • 最终输出

    normOut[2×num_iters−1]\mathbf{normOut}[2 \times \mathbf{num\_iters} - 1]

    sumOut[2×num_iters−1]\mathbf{sumOut}[2 \times \mathbf{num\_iters} - 1]

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 待计算数据,表示网络中mHC层的输入数据。 BFLOAT16, FLOAT16 ND
phi 输入 mHC的参数矩阵。 FLOAT32 ND
alpha 输入 mHC的缩放参数。 FLOAT32 ND
bias 输入 mHC的bias参数。 FLOAT32 ND
hcMult 可选输入 残差流数量,HC维度大小,当前仅支持4。 INT32 -
numIters 可选输入 表示sinkhorn算法迭代次数,当前仅支持20。 INT32 -
hcEps 可选输入 h_pre的sigmoid后的eps参数。 DOUBLE -
normEps 可选输入 RmsNorm的防除零参数。 DOUBLE -
needGrad 可选输入 是否需要输出额外属性。 BOOL -
hIn 输出 输出的h_in作为Attention/MLP层的输入。 BFLOAT16, FLOAT16 ND
hPost 输出 输出的mHC的h_post变换矩阵。 FLOAT32 ND
hRes 输出 输出的mHC的h_res变换矩阵。 FLOAT32 ND
hPre 可选输出 需要反向时输出,做完sigmoid计算之后的hPre矩阵。 FLOAT32 ND
hcBeforeNorm 可选输出 需要反向时输出,x与phi矩阵乘的结果。 FLOAT32 ND
invRms 可选输出 需要反向时输出,RmsNorm计算得到的1/r。 FLOAT32 ND
sumOut 可选输出 需要反向时输出,每一次迭代的colSum/rowSum结果。 FLOAT32 ND
normOut 可选输出 需要反向时输出,每一次colSum/rowSum迭代后的comb结果。 FLOAT32 ND

约束说明

  • n目前支持4。
  • 输入x的最后一维需要满足128对齐。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_mhc_pre_sinkhorn 通过aclnnMhcPreSinkHorn接口方式调用MhcPreSinkHorn算子。