MhcPre

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:基于一系列计算得到MHC架构中hidden层的HresH^{res}HpostH^{post}投影矩阵以及Attention或MLP层的输入矩阵hinh^{in}

  • 计算公式

xl′⃗=RMSNorm(xl⃗)Hlpre=αlpre⋅(xl′⃗φlpre)+blpreHlpost=αlpost⋅(xl′⃗φlpost)+blpostHlres=αlres⋅(xl′⃗φlres)+blresHlpre=σ(Hlpre)Hlpost=2σ(Hlpost)hin=xl′⃗Hlpre\begin{aligned} \vec{x^{'}_{l}} &=RMSNorm(\vec{x_{l}})\\ H^{pre}_l &= \alpha^{pre}_{l} ·(\vec{x^{'}_{l}}\varphi^{pre}_{l}) + b^{pre}_{l}\\ H^{post}_l &= \alpha^{post}_{l} ·(\vec{x^{'}_{l}}\varphi^{post}_{l}) + b^{post}_{l}\\ H^{res}_l &= \alpha^{res}_{l} ·(\vec{x^{'}_{l}}\varphi^{res}_{l}) + b^{res}_{l}\\ H^{pre}_l &= \sigma (H^{pre}_{l})\\ H^{post}_l &= 2\sigma (H^{post}_{l})\\ h_{in} &=\vec{x^{'}_{l}}H^{pre}_l \end{aligned}

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 待计算数据,表示网络中mHC层的输入数据。 BFLOAT16, FLOAT16 ND
phi 输入 mHC的参数矩阵。 FLOAT32 ND
alpha 输入 mHC的缩放参数。 FLOAT32 -
bias 输入 mHC的bias参数。 FLOAT32 -
gamma 可选输入 表示进行RmsNorm计算的缩放因子。 FLOAT32 ND
out_flag 可选输入 表示是否输出inv_rms/h_mix/h_pre, 默认为0表示不输出,为1表示全输出。 DOUBLE -
norm_eps 可选输入 RmsNorm的防除零参数。 DOUBLE -
hc_eps 可选输入 h_pre的sigmoid后的eps参数。 DOUBLE -
h_in 输出 输出的h_in作为Attention/MLP层的输入。 BFLOAT16, FLOAT16 ND
h_post 输出 输出的mHC的h_post变换矩阵。 FLOAT32 ND
h_res 输出 输出的mHC的h_res变换矩阵(未做sinkhorn变换)。 FLOAT32 ND
inv_rms 可选输出 RmsNorm计算得到的1/r。 FLOAT32 ND
h_mix 可选输出 x与phi矩阵乘的结果。 FLOAT32 ND
h_pre 可选输出 做完sigmoid计算之后的h_pre矩阵。 FLOAT32 ND

约束说明

  • n目前支持4、6、8。
  • D支持1~16384范围以内,需满足D为16对齐。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_mhc_pre 通过aclnnMhcPre接口方式调用MhcPre算子。