FusedCausalConv1d

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:对序列执行因果一维卷积,沿序列维度使用缓存数据(长度为卷积核宽减1)对各序列头部进行padding,确保输出依赖当前及历史输入;卷积完成后,将当前序列部分数据更新到缓存;在因果一维卷积输出的基础上,将原始输入加到输出上以实现残差连接。支持 APC(Automatic Prefix Caching)、MTP(投机解码)、残差连接等特性。

  • 本算子支持以下场景:

    • 场景一(prefill场景):

      x: [cu_seq_len, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch](无作用)
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [cu_seq_len, dim]
      

      其中cu_seq_len为batch内所有变长序列拼接后的总长度。

    • 场景二(prefill和decode混合场景):

      x: [cu_seq_len, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1+m, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch]
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [cu_seq_len, dim]
      

      其中cu_seq_len为batch内所有变长序列拼接后的总长度。

    • 场景三(decode场景 - 变长序列):

      x: [cu_seq_len, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1+m, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch]
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [cu_seq_len, dim]
      

      其中state_len必须大于所有batch中最大的token个数加1。

    • 场景四(decode场景 - 固定batch):

      x: [batch, m+1, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1+m, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch]或None, 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch]
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [batch, m+1, dim]
      
  • 计算公式:

    K是卷积核宽度(固定为3),L是原始序列长度,dim是特征维度。

    1. 缓存读取

      缓存行索引:

      readCacheLine={cacheIndices[batchId,  initialStateIdx[batchId]],APC 模式cacheIndices[batchId],非 APC 且 cacheIndices 存在batchId,其他readCacheLine = \begin{cases} cacheIndices[batchId, \; initialStateIdx[batchId]], & \text{APC 模式} \\ cacheIndices[batchId], & \text{非 APC 且 cacheIndices 存在} \\ batchId, & \text{其他} \end{cases}

      Case 1:首次计算(numComputedTokens[batchId] == 0)

      cachedState[i,dim]=0,0≤i<K−1cachedState[i, dim] = 0, \quad 0 \leq i < K-1

      offset=0offset = 0

      Case 2:投机解码模式(numAcceptedTokens 存在)

      offset=numAcceptedTokens[batchId]−1offset = numAcceptedTokens[batchId] - 1

      cachedState[i,dim]=convStates[readCacheLine][i,dim],0≤i<offset+K−1cachedState[i, dim] = convStates[readCacheLine][i, dim], \quad 0 \leq i < offset + K - 1

      Case 3:默认模式

      offset=C−(K−1)offset = C - (K - 1)

      cachedState[i,dim]=convStates[readCacheLine][i,dim],0≤i<offset+K−1cachedState[i, dim] = convStates[readCacheLine][i, dim], \quad 0 \leq i < offset + K - 1

    2. 缓存拼接

      paddedInput[i,dim]={cachedState[i,dim],0≤i<offset+K−1x[i−(offset+K−1),dim],offset+K−1≤i<offset+K−1+LpaddedInput[i, dim] = \begin{cases} cachedState[i, dim], & 0 \leq i < offset + K - 1 \\ x[i - (offset + K - 1), dim], & offset + K - 1 \leq i < offset + K - 1 + L \end{cases}

    3. 缓存更新

      Len=offset+K−1+LLen = offset + K - 1 + L

      M=min⁡(C,  Len)M = \min(C, \; Len)

      writeCacheLine={cacheIndices[batchId,  idxLast],APC 模式cacheIndices[batchId],非 APC 且 cacheIndices 存在batchId,其他writeCacheLine = \begin{cases} cacheIndices[batchId, \; idxLast], & \text{APC 模式} \\ cacheIndices[batchId], & \text{非 APC 且 cacheIndices 存在} \\ batchId, & \text{其他} \end{cases}

      convStates[writeCacheLine][C−M+i,dim]=paddedInput[Len−M+i,dim],i=0,1,…,M−1convStates[writeCacheLine][C - M + i, dim] = paddedInput[Len - M + i, dim], \quad i = 0, 1, \dots, M-1

    4. Offset 裁剪

      x′[i,dim]=paddedInput[i+offset,dim],0≤i<K−1+Lx'[i, dim] = paddedInput[i + offset, dim], \quad 0 \leq i < K - 1 + L

    5. APC 缓存填充(可选,APC 模式下)

      seqCompletedOffsetToken=numComputedTokens[batchId]mod  BseqCompletedOffsetToken = numComputedTokens[batchId] \mod B

      seqCompletedOffset=B−seqCompletedOffsetTokenseqCompletedOffset = B - seqCompletedOffsetToken

      seqEndOffset=(L−seqCompletedOffset)mod  BseqEndOffset = (L - seqCompletedOffset) \mod B

      lastFullBlockTokenIndex={L−seqEndOffset−B,seqEndOffset=0L−seqEndOffset,otherwiselastFullBlockTokenIndex = \begin{cases} L - seqEndOffset - B, & seqEndOffset = 0 \\ L - seqEndOffset, & \text{otherwise} \end{cases}

      nBlockToFill=idxLast−idxFirstnBlockToFill = idxLast - idxFirst

      对每个 chunk = 0, 1, ..., nBlockToFill - 1:

      boundaryIdx=lastFullBlockTokenIndex−(nBlockToFill−chunk−1)×BboundaryIdx = lastFullBlockTokenIndex - (nBlockToFill - chunk - 1) \times B

      convStates[cacheIndices[batchId,  idxFirst+chunk]][C−(K−1)+j,  dim]=x′[boundaryIdx+j,  dim],j=0,…,K−2convStates[cacheIndices[batchId, \; idxFirst + chunk]][C-(K-1)+j, \; dim] = x'[boundaryIdx + j, \; dim], \quad j = 0, \dots, K-2

    6. 因果1维卷积

      y[i,dim]=∑k=0K−1w[k,dim]⋅x′[i+k,dim],i=0,1,…,L−1y[i, dim] = \sum_{k=0}^{K-1} w[k, dim] \cdot x'[i + k, dim], \quad i = 0, 1, \dots, L-1

    7. 零填充重置(可选,当convMode == 1 并且 numComputedTokens不为空时)

      resetIdx=min⁡ ⁣(max⁡ ⁣(K−1−numComputedTokens[batchId],  0),  L)resetIdx = \min\!\Big(\max\!\big(K - 1 - numComputedTokens[batchId], \; 0\big), \; L\Big)

      y[i,dim]=0,0≤i<resetIdxy[i, dim] = 0, \quad 0 \leq i < resetIdx

    8. 残差连接(可选)

      y[i,dim]=x[i,dim]+y[i,dim]y[i, dim] = x[i, dim] + y[i, dim]

参数说明

参数名 输入/输出 描述 数据类型 数据格式
x 输入/输出 公式中的输入序列x。 BFLOAT16、FLOAT16 ND
weight 输入 公式中的因果1维卷积核w。 同 x ND
conv_states 输入/输出
  • 公式中的convStates,缓存状态张量,存储各序列的历史 token 数据。
  • 各序列计算完成后原地更新。
同 x ND
query_start_loc 可选输入
  • x为二维场景下,序列起始位置索引,记录各序列在拼接张量 x 中的起始位置。
  • queryStartLoc[i] 表示第 i 个序列的起始偏移。queryStartLoc[0]必须为0,queryStartLoc[-1]必须为cu_seq_len,相邻两个数据不相等。
INT32 ND
cache_indices 可选输入 缓存索引,指定每个序列对应的缓存状态在 cacheState 中的索引。 INT32 ND
initial_state_mode 可选输入 制定每个序列对应的 padding 策略。 INT32 ND
bias 可选输入 卷积的偏置。 同 x ND
num_accepted_tokens 可选输入 公式中的numAcceptedTokens。 INT32 ND
num_computed_tokens 可选输入 公式中的numComputedTokens,当前 batch 已经处理的 token 总数,用于判断初始状态。 INT32 ND
block_idx_first_scheduled_token 可选输入 当前 batch 的第一个 token 对应的 block 索引。 INT32 ND
block_idx_last_scheduled_token 可选输入 当前 batch 的最后一个 token 对应的 block 索引。 INT32 ND
initial_state_idx 可选输入 初始索引块的索引。 INT32 ND
activation_mode 可选输入 激活函数类型。 STR -
pad_slot_id 可选输入 用于跳过不需要参与计算的变长序列。 INT64 -
run_mode 可选输入 表示 prefill 或者 decode 场景。历史遗留接口,暂不支持此字段。 INT64 -
max_query_len 可选输入 所有 batch 中的最大 seq_len,支持为-1。 INT64 -
residual_connection 可选输入 用于残差连接。 INT64 -
block_size 可选输入 block 块的大小。 INT64 -
conv_mode 可选输入 公式中的convMode,支持 Qwen3-Next 和 Pangu V2 两种实现。 INT64 -
y 输出 x 经过conv1d 计算后的结果。 同 x ND

约束说明

  • 输入shape限制:

    • prefill场景:
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • cu_seq_len范围[batch, 1024 * 1024],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
    • prefill和decode混合场景:
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • cu_seq_len范围[batch, 1024 * 1024],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
    • decode场景(固定batch):
      • x支持3维[batch, m+1, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • m范围[0, 7],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
    • decode场景(变长序列):
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., k-1+m, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • cu_seq_len范围[batch, batch*8],每个batch的token个数范围为[1, 8]。dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
  • 输入值域限制:

    • query_start_loc是累计偏移量,取值范围[0, cu_seq_len],长度为batch+1,query_start_loc[i]表示第i个序列的起始偏移,query_start_loc[batch+1]表示最后一个序列的结束位置。
    • blockSize 必须大于等于 2。
    • APC 开启时,必须提供 blockIdxFirstScheduledToken、blockIdxLastScheduledToken 及 initialStateIdx,且满足如下需求,i为batch的索引: - initialStateIdx[i] <= blockIdxFirstScheduledToken[i]+1 - initialStateIdx[i] <= blockIdxLastScheduledToken[i] - blockIdxFirstScheduledToken[i] <= blockIdxLastScheduledToken[i]
    • num_accepted_tokens分为None和非None,非None情况下长度为batch,每个元素取值不超过当前batch的token个数且大于0。
    • Pangu V2 模式(conv_mode = 1)下,首次运行 numComputedTokens 不能为 None。
    • 算子入参与中间计算结果,在对应运行数据类型(float16/bfloat16) 下,数值均不会超出该类型值域范围。
    • 算子输入不支持有±inf和nan的情况。

    调用说明

    调用方式 样例代码 说明
    aclnn接口 test_aclnn_fused_causal_conv1d 通过aclnnFusedCausalConv1d调用FusedCausalConv1d算子
    图模式 - 通过算子IR构图方式调用FusedCausalConv1d算子