mamba2_rmsnormgated 算子说明
功能和实现说明
基于状态空间模型(SSM)的因果卷积,实现 MambaV2 Prefill 阶段的因果卷积计算。计算流程包含 kernel_size=4 的 depthwise conv1d 和 SiLU 激活。本算子采用纯 Vector 实现 conv1d,并融合 bias 和 SiLU 运算以提升性能。
计算流

自定义Kernel输入输出(I/O)
输入
| Tensor | shape | dtype |
|---|---|---|
| x | BSD | FP32 |
| w | D | FP32 |
| z | BSD | FP32 |
输出
| Tensor | shape | dtype |
|---|---|---|
| out | BSD | FP32 |
参数说明:
B: batch size
S: sequence len
D: dimension
额外需要参数
G: ngroups
E: eps
调用方式
import npu_ops_transformer_ext
out = torch.ops.npu_ops_transformer_ext.mamba2_rmsnormgated(x, z, w, G, E)
测试方法
见当前目录 tests/
python test_rmsnormgated.py