mamba2_causal_conv1d 算子说明
功能和实现说明
基于状态空间模型(SSM)的因果卷积,实现 MambaV2 Prefill 阶段的因果卷积计算。计算流程包含 kernel_size=4 的 depthwise conv1d 和 SiLU 激活。本算子采用纯 Vector 实现 conv1d,并融合 bias 和 SiLU 运算以提升性能。
计算流

自定义Kernel输入输出(I/O)
输入
| Tensor | shape | dtype |
|---|---|---|
| x | BDS | FP32 |
| w | BDS | FP32 |
| b | D | FP16 |
输出
| Tensor | shape | dtype |
|---|---|---|
| out | BDS | FP32 |
参数说明:
B: batch size
D: dimension
S: sequence len
该算子支持任意长度S
调用方式
import npu_ops_transformer_ext
out = torch.ops.npu_ops_transformer_ext.mamba2_causal_conv1d(x, w, b)
测试方法
见当前目录 tests/
python test_causal_conv1d.py