mamba2_chunk_cumsum 算子说明

功能和实现说明

mamba2_chunk_cumsum 用于在 MambaV2 Prefill 阶段对 chunk 内部执行按时间步的累积求和操作,实现 SSM 中状态量在 chunk 维度上的递推更新。算子对输入序列在 S 维度按 chunk_size 拆分,并在每个 chunk 内按照因果顺序执行 cumulative sum,用于后续 chunk 状态更新与 selective scan 计算。本算子基于 Vector 实现累积求和计算,支持 FP16/FP32 输入输出。

计算流

Kernel输入输出(I/O)

输入

Tensor shape dtype
at H FP32
dt BCLH FP16
dt_bias H FP16
dt_mask BCLH FP16

输出

Tensor shape dtype
dtout BCLH FP32
dacs BCLH FP32
dacs_chunk BCH FP32

参数说明:

B: batch size
C: number of chunks
L: chunk size
H: number of head
其中C*L为padding后的序列长度

调用方式

import npu_ops_transformer_ext

out = torch.ops.npu_ops_transformer_ext.mamba2_chunk_cumsum(at, dt, dt_bias, dt_mask)

测试方法

见当前目录 tests/

python test_chunk_cumsum.py