FusedCausalConv1d
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:对序列执行因果一维卷积,沿序列维度使用缓存数据(长度为卷积核宽减1)对各序列头部进行padding,确保输出依赖当前及历史输入;卷积完成后,将当前序列部分数据更新到缓存;在因果一维卷积输出的基础上,将原始输入加到输出上以实现残差连接。支持 APC(Automatic Prefix Caching)、MTP(投机解码)、残差连接等特性。
-
本算子支持以下场景:
-
场景一(prefill场景):
x: [cu_seq_len, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1, dim] query_start_loc: [batch+1] cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch](无作用) num_computed_tokens: [batch] block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch] block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch] initial_state_idx: 不开APC:None, 开APC:[batch] activation_mode: (无作用) pad_slot_id: 默认值 -1 run_mode: (无作用) max_query_len:默认值 1 residual_connection: 不做残差: 0, 做残差:1 block_size: 典型值 128/256 conv_mode:Qwen3-Next模式: 0, Pangu V2: 1 y: [cu_seq_len, dim]其中cu_seq_len为batch内所有变长序列拼接后的总长度。
-
场景二(prefill和decode混合场景):
x: [cu_seq_len, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1+m, dim] query_start_loc: [batch+1] cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch] num_computed_tokens: [batch] block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch] block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch] initial_state_idx: 不开APC:None, 开APC:[batch] activation_mode: (无作用) pad_slot_id: 默认值 -1 run_mode: (无作用) max_query_len:默认值 1 residual_connection: 不做残差: 0, 做残差:1 block_size: 典型值 128/256 conv_mode:Qwen3-Next模式: 0, Pangu V2: 1 y: [cu_seq_len, dim]其中cu_seq_len为batch内所有变长序列拼接后的总长度。
-
场景三(decode场景 - 变长序列):
x: [cu_seq_len, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1+m, dim] query_start_loc: [batch+1] cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch] num_computed_tokens: [batch] block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch] block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch] initial_state_idx: 不开APC:None, 开APC:[batch] activation_mode: (无作用) pad_slot_id: 默认值 -1 run_mode: (无作用) max_query_len:默认值 1 residual_connection: 不做残差: 0, 做残差:1 block_size: 典型值 128/256 conv_mode:Qwen3-Next模式: 0, Pangu V2: 1 y: [cu_seq_len, dim]其中state_len必须大于所有batch中最大的token个数加1。
-
场景四(decode场景 - 固定batch):
x: [batch, m+1, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1+m, dim] query_start_loc: [batch+1] cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch] num_computed_tokens: [batch] block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch] block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch] initial_state_idx: 不开APC:None, 开APC:[batch] activation_mode: (无作用) pad_slot_id: 默认值 -1 run_mode: (无作用) max_query_len:默认值 1 residual_connection: 不做残差: 0, 做残差:1 block_size: 典型值 128/256 conv_mode:Qwen3-Next模式: 0, Pangu V2: 1 y: [batch, m+1, dim]
-
-
计算公式:
K是卷积核宽度(固定为3),L是原始序列长度,dim是特征维度。
-
缓存读取
缓存行索引:
readCacheLine={cacheIndices[batchId, initialStateIdx[batchId]],APC 模式cacheIndices[batchId],非 APC 且 cacheIndices 存在batchId,其他readCacheLine = \begin{cases} cacheIndices[batchId, \; initialStateIdx[batchId]], & \text{APC 模式} \\ cacheIndices[batchId], & \text{非 APC 且 cacheIndices 存在} \\ batchId, & \text{其他} \end{cases}
Case 1:首次计算(numComputedTokens[batchId] == 0)
cachedState[i,dim]=0,0≤i<K−1cachedState[i, dim] = 0, \quad 0 \leq i < K-1
offset=0offset = 0
Case 2:投机解码模式(numAcceptedTokens 存在)
offset=numAcceptedTokens[batchId]−1offset = numAcceptedTokens[batchId] - 1
cachedState[i,dim]=convStates[readCacheLine][i,dim],0≤i<offset+K−1cachedState[i, dim] = convStates[readCacheLine][i, dim], \quad 0 \leq i < offset + K - 1
Case 3:默认模式
offset=C−(K−1)offset = C - (K - 1)
cachedState[i,dim]=convStates[readCacheLine][i,dim],0≤i<offset+K−1cachedState[i, dim] = convStates[readCacheLine][i, dim], \quad 0 \leq i < offset + K - 1
-
缓存拼接
paddedInput[i,dim]={cachedState[i,dim],0≤i<offset+K−1x[i−(offset+K−1),dim],offset+K−1≤i<offset+K−1+LpaddedInput[i, dim] = \begin{cases} cachedState[i, dim], & 0 \leq i < offset + K - 1 \\ x[i - (offset + K - 1), dim], & offset + K - 1 \leq i < offset + K - 1 + L \end{cases}
-
缓存更新
Len=offset+K−1+LLen = offset + K - 1 + L
M=min(C, Len)M = \min(C, \; Len)
writeCacheLine={cacheIndices[batchId, idxLast],APC 模式cacheIndices[batchId],非 APC 且 cacheIndices 存在batchId,其他writeCacheLine = \begin{cases} cacheIndices[batchId, \; idxLast], & \text{APC 模式} \\ cacheIndices[batchId], & \text{非 APC 且 cacheIndices 存在} \\ batchId, & \text{其他} \end{cases}
convStates[writeCacheLine][C−M+i,dim]=paddedInput[Len−M+i,dim],i=0,1,…,M−1convStates[writeCacheLine][C - M + i, dim] = paddedInput[Len - M + i, dim], \quad i = 0, 1, \dots, M-1
-
Offset 裁剪
x′[i,dim]=paddedInput[i+offset,dim],0≤i<K−1+Lx'[i, dim] = paddedInput[i + offset, dim], \quad 0 \leq i < K - 1 + L
-
APC 缓存填充(可选,APC 模式下)
seqCompletedOffsetToken=numComputedTokens[batchId]mod BseqCompletedOffsetToken = numComputedTokens[batchId] \mod B
seqCompletedOffset=B−seqCompletedOffsetTokenseqCompletedOffset = B - seqCompletedOffsetToken
seqEndOffset=(L−seqCompletedOffset)mod BseqEndOffset = (L - seqCompletedOffset) \mod B
lastFullBlockTokenIndex={L−seqEndOffset−B,seqEndOffset=0L−seqEndOffset,otherwiselastFullBlockTokenIndex = \begin{cases} L - seqEndOffset - B, & seqEndOffset = 0 \\ L - seqEndOffset, & \text{otherwise} \end{cases}
nBlockToFill=idxLast−idxFirstnBlockToFill = idxLast - idxFirst
对每个 chunk = 0, 1, ..., nBlockToFill - 1:
boundaryIdx=lastFullBlockTokenIndex−(nBlockToFill−chunk−1)×BboundaryIdx = lastFullBlockTokenIndex - (nBlockToFill - chunk - 1) \times B
convStates[cacheIndices[batchId, idxFirst+chunk]][C−(K−1)+j, dim]=x′[boundaryIdx+j, dim],j=0,…,K−2convStates[cacheIndices[batchId, \; idxFirst + chunk]][C-(K-1)+j, \; dim] = x'[boundaryIdx + j, \; dim], \quad j = 0, \dots, K-2
-
因果1维卷积
y[i,dim]=∑k=0K−1w[k,dim]⋅x′[i+k,dim],i=0,1,…,L−1y[i, dim] = \sum_{k=0}^{K-1} w[k, dim] \cdot x'[i + k, dim], \quad i = 0, 1, \dots, L-1
-
零填充重置(可选,当convMode == 1 并且 numComputedTokens不为空时)
resetIdx=min (max (K−1−numComputedTokens[batchId], 0), L)resetIdx = \min\!\Big(\max\!\big(K - 1 - numComputedTokens[batchId], \; 0\big), \; L\Big)
y[i,dim]=0,0≤i<resetIdxy[i, dim] = 0, \quad 0 \leq i < resetIdx
-
残差连接(可选)
y[i,dim]=x[i,dim]+y[i,dim]y[i, dim] = x[i, dim] + y[i, dim]
-
参数说明
| 参数名 | 输入/输出 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入/输出 | 公式中的输入序列x。 | BFLOAT16、FLOAT16 | ND |
| weight | 输入 | 公式中的因果1维卷积核w。 | 同 x | ND |
| conv_states | 输入/输出 |
|
同 x | ND |
| query_start_loc | 可选输入 |
|
INT32 | ND |
| cache_indices | 可选输入 | 缓存索引,指定每个序列对应的缓存状态在 cacheState 中的索引。 | INT32 | ND |
| initial_state_mode | 可选输入 | 制定每个序列对应的 padding 策略。 | INT32 | ND |
| bias | 可选输入 | 卷积的偏置。 | 同 x | ND |
| num_accepted_tokens | 可选输入 | 公式中的numAcceptedTokens。 | INT32 | ND |
| num_computed_tokens | 可选输入 | 公式中的numComputedTokens,当前 batch 已经处理的 token 总数,用于判断初始状态。 | INT32 | ND |
| block_idx_first_scheduled_token | 可选输入 | 当前 batch 的第一个 token 对应的 block 索引。 | INT32 | ND |
| block_idx_last_scheduled_token | 可选输入 | 当前 batch 的最后一个 token 对应的 block 索引。 | INT32 | ND |
| initial_state_idx | 可选输入 | 初始索引块的索引。 | INT32 | ND |
| activation_mode | 可选输入 | 激活函数类型。 | STR | - |
| pad_slot_id | 可选输入 | 用于跳过不需要参与计算的变长序列。 | INT64 | - |
| run_mode | 可选输入 | 表示 prefill 或者 decode 场景。历史遗留接口,暂不支持此字段。 | INT64 | - |
| max_query_len | 可选输入 | 所有 batch 中的最大 seq_len,支持为-1。 | INT64 | - |
| residual_connection | 可选输入 | 用于残差连接。 | INT64 | - |
| block_size | 可选输入 | block 块的大小。 | INT64 | - |
| conv_mode | 可选输入 | 公式中的convMode,支持 Qwen3-Next 和 Pangu V2 两种实现。 | INT64 | - |
| y | 输出 | x 经过conv1d 计算后的结果。 | 同 x | ND |
约束说明
-
输入shape限制:
- prefill场景:
- x支持2维[cu_seq_len, dim]。
- weight必须是2维[K, dim],其中K固定为3。
- conv_states必须是3维[..., K-1, dim],第0维大小不固定且大于等于batch。
- cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
- cu_seq_len范围[batch, 1024 * 1024],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
- prefill和decode混合场景:
- x支持2维[cu_seq_len, dim]。
- weight必须是2维[K, dim],其中K固定为3。
- conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
- cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
- cu_seq_len范围[batch, 1024 * 1024],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
- decode场景(固定batch):
- x支持3维[batch, m+1, dim]。
- weight必须是2维[K, dim],其中K固定为3。
- conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
- cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
- m范围[0, 7],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
- decode场景(变长序列):
- x支持2维[cu_seq_len, dim]。
- weight必须是2维[K, dim],其中K固定为3。
- conv_states必须是3维[..., k-1+m, dim],第0维大小不固定且大于等于batch。
- cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
- cu_seq_len范围[batch, batch*8],每个batch的token个数范围为[1, 8]。dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
- prefill场景:
-
输入值域限制:
- query_start_loc是累计偏移量,取值范围[0, cu_seq_len],长度为batch+1,query_start_loc[i]表示第i个序列的起始偏移,query_start_loc[batch+1]表示最后一个序列的结束位置。
- blockSize 必须大于等于 2。
- APC 开启时,必须提供 blockIdxFirstScheduledToken、blockIdxLastScheduledToken 及 initialStateIdx,且满足如下需求,i为batch的索引: - initialStateIdx[i] <= blockIdxFirstScheduledToken[i]+1 - initialStateIdx[i] <= blockIdxLastScheduledToken[i] - blockIdxFirstScheduledToken[i] <= blockIdxLastScheduledToken[i]
- num_accepted_tokens分为None和非None,非None情况下长度为batch,每个元素取值不超过当前batch的token个数且大于0。
- Pangu V2 模式(conv_mode = 1)下,首次运行 numComputedTokens 不能为 None。
- 算子入参与中间计算结果,在对应运行数据类型(float16/bfloat16) 下,数值均不会超出该类型值域范围。
- 算子输入不支持有±inf和nan的情况。
调用说明
调用方式 样例代码 说明 aclnn接口 test_aclnn_fused_causal_conv1d 通过aclnnFusedCausalConv1d调用FusedCausalConv1d算子 图模式 - 通过算子IR构图方式调用FusedCausalConv1d算子