mamba2_chunk_state 算子说明
功能和实现说明
mamba2_chunk_state 用于在 MambaV2 Prefill 阶段进行 chunk 内的离散时间状态更新,根据 chunk_cumsum 得到的累积量dacs/dacs_chunk和状态更新因子dtout进行状态递推,输出 chunk 内每一步的状态序列,并生成用于下一 chunk 的最终隐藏状态。本算子实现为 Vector+cube 融合算子,支持 FP16/FP32。
计算流

Kernel输入输出(I/O)
输入
| Tensor | shape | dtype |
|---|---|---|
| dtout | BCLH | FP32 |
| dacs | BCLH | FP32 |
| bt | BCLGN | FP16 |
| xt | BCLHP | FP16 |
输出
| Tensor | shape | dtype |
|---|---|---|
| states | BCHNP | FP32 |
参数说明:
B: batch size
C: number of chunks
L: chunk size
H: number of head
G: ngroups
N: state size
P: head dim
其中C*L为padding后的序列长度
调用方式
import npu_ops_transformer_ext
out = torch.ops.npu_ops_transformer_ext.mamba2_chunk_state(dtout, dacs, bt, xt)
测试方法
见当前目录 tests/
python test_chunk_state.py