MhcPost

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:MhcPost基于一系列计算对mHC架构中上一层输出htouth_{t}^{out}进行Post Mapping,对上一层的输入xlx_l进行Res Mapping,然后对二者进行残差连接,得到下一层的输入xl+1x_{l+1}

  • 计算公式:

    xl+1=(Hlres)T×xl+hlout⊗Htpostx_{l+1} = (H_{l}^{res})^{T} \times x_l + h_{l}^{out} \otimes H_{t}^{post}

参数说明

参数名 输入/输出 描述 数据类型 数据格式
x 输入 待计算的张量,表示网络中mHC层的输入数据。 FLOAT16、BFLOAT16 ND
h_res 输入 mHC的h_res变换矩阵,是做完sinkhorn变换后的双随机矩阵。 FLOAT32 ND
h_out 输入 Atten/MLP层的输出。 FLOAT16、BFLOAT16 ND
h_post 输入 mHC的h_post变换矩阵。 FLOAT32 ND
out 输出 网络中mHC层的输出数据,作为下一层的输入。 FLOAT16、BFLOAT16 ND

约束说明

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_mhc_post 通过aclnnMhcPost接口方式调用MhcPost算子。