MhcPre
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:基于一系列计算得到MHC架构中hidden层的HresH^{res}和HpostH^{post}投影矩阵以及Attention或MLP层的输入矩阵hinh^{in}。
-
计算公式:
xl′⃗=RMSNorm(xl⃗)Hlpre=αlpre⋅(xl′⃗φlpre)+blpreHlpost=αlpost⋅(xl′⃗φlpost)+blpostHlres=αlres⋅(xl′⃗φlres)+blresHlpre=σ(Hlpre)Hlpost=2σ(Hlpost)hin=xl′⃗Hlpre\begin{aligned} \vec{x^{'}_{l}} &=RMSNorm(\vec{x_{l}})\\ H^{pre}_l &= \alpha^{pre}_{l} ·(\vec{x^{'}_{l}}\varphi^{pre}_{l}) + b^{pre}_{l}\\ H^{post}_l &= \alpha^{post}_{l} ·(\vec{x^{'}_{l}}\varphi^{post}_{l}) + b^{post}_{l}\\ H^{res}_l &= \alpha^{res}_{l} ·(\vec{x^{'}_{l}}\varphi^{res}_{l}) + b^{res}_{l}\\ H^{pre}_l &= \sigma (H^{pre}_{l})\\ H^{post}_l &= 2\sigma (H^{post}_{l})\\ h_{in} &=\vec{x^{'}_{l}}H^{pre}_l \end{aligned}
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 待计算数据,表示网络中mHC层的输入数据。 | BFLOAT16, FLOAT16 | ND |
| phi | 输入 | mHC的参数矩阵。 | FLOAT32 | ND |
| alpha | 输入 | mHC的缩放参数。 | FLOAT32 | - |
| bias | 输入 | mHC的bias参数。 | FLOAT32 | - |
| gamma | 可选输入 | 表示进行RmsNorm计算的缩放因子。 | FLOAT32 | ND |
| out_flag | 可选输入 | 表示是否输出inv_rms/h_mix/h_pre, 默认为0表示不输出,为1表示全输出。 | DOUBLE | - |
| norm_eps | 可选输入 | RmsNorm的防除零参数。 | DOUBLE | - |
| hc_eps | 可选输入 | h_pre的sigmoid后的eps参数。 | DOUBLE | - |
| h_in | 输出 | 输出的h_in作为Attention/MLP层的输入。 | BFLOAT16, FLOAT16 | ND |
| h_post | 输出 | 输出的mHC的h_post变换矩阵。 | FLOAT32 | ND |
| h_res | 输出 | 输出的mHC的h_res变换矩阵(未做sinkhorn变换)。 | FLOAT32 | ND |
| inv_rms | 可选输出 | RmsNorm计算得到的1/r。 | FLOAT32 | ND |
| h_mix | 可选输出 | x与phi矩阵乘的结果。 | FLOAT32 | ND |
| h_pre | 可选输出 | 做完sigmoid计算之后的h_pre矩阵。 | FLOAT32 | ND |
约束说明
- n目前支持4、6、8。
- D支持1~16384范围以内,需满足D为16对齐。
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_mhc_pre | 通过aclnnMhcPre接口方式调用MhcPre算子。 |