aclnnFusedCausalConv1d

📄 查看源码

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 接口功能:对序列执行因果一维卷积,沿序列维度使用缓存数据(长度为卷积核宽减1)对各序列头部进行padding,确保输出依赖当前及历史输入;卷积完成后,将当前序列部分数据更新到缓存;在因果一维卷积输出的基础上,将原始输入加到输出上以实现残差连接。支持 APC(Automatic Prefix Caching)、MTP(投机解码)、残差连接等特性。相较于标准 causal_conv1d 算子,本算子新增 APC 缓存复用、PD混部、残差连接可选等功能。

  • 支持以下场景:

    • 场景一(prefill场景):

      x: [cu_seq_len, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch], 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch](无作用)
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [cu_seq_len, dim]
      

      其中cu_seq_len为batch内所有变长序列拼接后的总长度。

    • 场景二(prefill和decode混合场景):

      x: [cu_seq_len, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1+m, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch], 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch]
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [cu_seq_len, dim]
      

      其中cu_seq_len为batch内所有变长序列拼接后的总长度。

    • 场景三(decode场景 - 变长序列):

      x: [cu_seq_len, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1+m, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch], 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch]
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [cu_seq_len, dim]
      

      其中state_len必须大于所有batch中最大的token个数加1。

    • 场景四(decode场景 - 固定batch):

      x: [batch, m+1, dim]
      weight: [K, dim],其中K=3
      conv_states: [-1, K-1+m, dim]
      query_start_loc: [batch+1]
      cache_indices: 不开APC:[batch], 开APC:[block, maxNumBlocks]
      initial_state_mode: [batch]
      bias: [dim](无作用)
      num_accepted_tokens: [batch]
      num_computed_tokens: [batch]
      block_idx_first_scheduled_token: 不开APC:None, 开APC:[batch]
      block_idx_last_scheduled_token: 不开APC:None, 开APC:[batch]
      initial_state_idx: 不开APC:None, 开APC:[batch]
      activation_mode: (无作用)
      pad_slot_id: 默认值 -1
      run_mode: (无作用)
      max_query_len:默认值 1
      residual_connection: 不做残差: 0, 做残差:1
      block_size: 典型值 128/256
      conv_mode:Qwen3-Next模式: 0, Pangu V2: 1
      y: [batch, m+1, dim]
      
  • 计算公式:

    K是卷积核宽度(固定为3),L是原始序列长度,dim是特征维度。

    1. 缓存读取

      缓存行索引:

      readCacheLine={cacheIndices[batchId,  initialStateIdx[batchId]],APC 模式cacheIndices[batchId],非 APC 且 cacheIndices 存在batchId,其他readCacheLine = \begin{cases} cacheIndices[batchId, \; initialStateIdx[batchId]], & \text{APC 模式} \\ cacheIndices[batchId], & \text{非 APC 且 cacheIndices 存在} \\ batchId, & \text{其他} \end{cases}

      Case 1:首次计算(numComputedTokens[batchId] == 0)

      cachedState[i,dim]=0,0≤i<K−1cachedState[i, dim] = 0, \quad 0 \leq i < K-1

      offset=0offset = 0

      Case 2:投机解码模式(numAcceptedTokens 存在)

      offset=numAcceptedTokens[batchId]−1offset = numAcceptedTokens[batchId] - 1

      cachedState[i,dim]=convStates[readCacheLine][i,dim],0≤i<offset+K−1cachedState[i, dim] = convStates[readCacheLine][i, dim], \quad 0 \leq i < offset + K - 1

      Case 3:默认模式

      offset=C−(K−1)offset = C - (K - 1)

      cachedState[i,dim]=convStates[readCacheLine][i,dim],0≤i<offset+K−1cachedState[i, dim] = convStates[readCacheLine][i, dim], \quad 0 \leq i < offset + K - 1

    2. 缓存拼接

      paddedInput[i,dim]={cachedState[i,dim],0≤i<offset+K−1x[i−(offset+K−1),dim],offset+K−1≤i<offset+K−1+LpaddedInput[i, dim] = \begin{cases} cachedState[i, dim], & 0 \leq i < offset + K - 1 \\ x[i - (offset + K - 1), dim], & offset + K - 1 \leq i < offset + K - 1 + L \end{cases}

    3. 缓存更新

      Len=offset+K−1+LLen = offset + K - 1 + L

      M=min⁡(C,  Len)M = \min(C, \; Len)

      writeCacheLine={cacheIndices[batchId,  idxLast],APC 模式cacheIndices[batchId],非 APC 且 cacheIndices 存在batchId,其他writeCacheLine = \begin{cases} cacheIndices[batchId, \; idxLast], & \text{APC 模式} \\ cacheIndices[batchId], & \text{非 APC 且 cacheIndices 存在} \\ batchId, & \text{其他} \end{cases}

      convStates[writeCacheLine][C−M+i,dim]=paddedInput[Len−M+i,dim],i=0,1,…,M−1convStates[writeCacheLine][C - M + i, dim] = paddedInput[Len - M + i, dim], \quad i = 0, 1, \dots, M-1

    4. Offset 裁剪

      x′[i,dim]=paddedInput[i+offset,dim],0≤i<K−1+Lx'[i, dim] = paddedInput[i + offset, dim], \quad 0 \leq i < K - 1 + L

    5. APC 缓存填充(可选,APC 模式下)

      seqCompletedOffsetToken=numComputedTokens[batchId]mod  BseqCompletedOffsetToken = numComputedTokens[batchId] \mod B

      seqCompletedOffset=B−seqCompletedOffsetTokenseqCompletedOffset = B - seqCompletedOffsetToken

      seqEndOffset=(L−seqCompletedOffset)mod  BseqEndOffset = (L - seqCompletedOffset) \mod B

      lastFullBlockTokenIndex={L−seqEndOffset−B,seqEndOffset=0L−seqEndOffset,otherwiselastFullBlockTokenIndex = \begin{cases} L - seqEndOffset - B, & seqEndOffset = 0 \\ L - seqEndOffset, & \text{otherwise} \end{cases}

      nBlockToFill=idxLast−idxFirstnBlockToFill = idxLast - idxFirst

      对每个 chunk = 0, 1, ..., nBlockToFill - 1:

      boundaryIdx=lastFullBlockTokenIndex−(nBlockToFill−chunk−1)×BboundaryIdx = lastFullBlockTokenIndex - (nBlockToFill - chunk - 1) \times B

      convStates[cacheIndices[batchId,  idxFirst+chunk]][C−(K−1)+j,  dim]=x′[boundaryIdx+j,  dim],j=0,…,K−2convStates[cacheIndices[batchId, \; idxFirst + chunk]][C-(K-1)+j, \; dim] = x'[boundaryIdx + j, \; dim], \quad j = 0, \dots, K-2

    6. 因果1维卷积

      y[i,dim]=∑k=0K−1w[k,dim]⋅x′[i+k,dim],i=0,1,…,L−1y[i, dim] = \sum_{k=0}^{K-1} w[k, dim] \cdot x'[i + k, dim], \quad i = 0, 1, \dots, L-1

    7. 零填充重置(可选,当convMode == 1 并且 numComputedTokens不为空时)

      resetIdx=min⁡ ⁣(max⁡ ⁣(K−1−numComputedTokens[batchId],  0),  L)resetIdx = \min\!\Big(\max\!\big(K - 1 - numComputedTokens[batchId], \; 0\big), \; L\Big)

      y[i,dim]=0,0≤i<resetIdxy[i, dim] = 0, \quad 0 \leq i < resetIdx

    8. 残差连接(可选)

      y[i,dim]=x[i,dim]+y[i,dim]y[i, dim] = x[i, dim] + y[i, dim]

函数原型

每个算子分为两段式接口,必须先调用 aclnnFusedCausalConv1dGetWorkspaceSize接口获取入参并计算所需workspace大小以及包含了算子计算流程的执行器,再调用aclnnFusedCausalConv1d接口执行计算。

aclnnStatus aclnnFusedCausalConv1dGetWorkspaceSize(
  const aclTensor *x,
  const aclTensor *weight,
  aclTensor       *convStates,
  const aclTensor *queryStartLoc,
  const aclTensor *cacheIndices,
  const aclTensor *initialStateMode,
  const aclTensor *bias,
  const aclTensor *numAcceptedTokens,
  const aclTensor *numComputedTokens,
  const aclTensor *blockIdxFirstScheduledToken,
  const aclTensor *blockIdxLastScheduledToken,
  const aclTensor *initialStateIdx,
  int64_t          activationMode,
  int64_t          padSlotId,
  int64_t          runMode,
  int64_t          maxQueryLen,
  int64_t          residualConnection,
  int64_t          blockSize,
  int64_t          convMode,
  aclTensor       *y,
  uint64_t        *workspaceSize,
  aclOpExecutor  **executor)
aclnnStatus aclnnFusedCausalConv1d(
  void          *workspace,
  uint64_t       workspaceSize,
  aclOpExecutor *executor,
  aclrtStream    stream)

aclnnFusedCausalConv1dGetWorkspaceSize

  • 参数说明

    参数名 输入/输出 描述 使用说明 数据类型 数据格式 维度(shape) 非连续Tensor
    x(aclTensor*) 输入 计算公式中的x,代表输入序列。
    • 不支持空tensor。
    • prefill场景:shape为[cu_seq_len, dim]。
    • decode场景:shape为[cu_seq_len, dim]或[batch, seq_len, dim]。
    FLOAT16、BFLOAT16 ND 2-3
    weight(aclTensor*) 输入 计算公式中的w,代表因果1维卷积核。
    • 不支持空tensor。
    • shape为[K, dim]。
    数据类型与x一致 ND 2 ×
    convStates(aclTensor*) 输入/输出 计算公式中的cacheState,代表缓存状态张量,存储各序列的历史token数据,各序列计算完成后原地更新。
    • 不支持空tensor。
    • shape为[..., stateLen, dim],第 0 维的大小不固定。
    • stateLen == K-1+m,prefill 场景下 m=0,decode、PD混部场景下 m=[0,7]。
    • m 表示投机个数,映射 numAcceptedTokens 参数的值。
    数据类型与x一致 ND 3
    queryStartLoc(aclTensor*) 可选输入 序列起始位置索引,记录各序列在拼接张量x中的起始位置。
    • 不支持空tensor。
    • shape为[batch+1,]。
    • queryStartLoc[i]表示第i个序列的起始偏移。
    • queryStartLoc[0]必须为0,queryStartLoc[-1]必须为cu_seq_len,相邻两个数据不相等。
    INT32 ND 1 ×
    cacheIndices(aclTensor*) 可选输入 缓存索引,指定每个序列对应的缓存状态在convStates中的索引。
    • 不支持空tensor,APC开启时不能为None且必须为二维。
    • 1维shape[batch,]或2维shape[batch, maxNumBlocks], maxNumBlocks范围[1,1024]。
    • 值需要互不相同(除非等于 padSlotId)。
    INT32 ND 1-2 ×
    initialStateMode(aclTensor*) 可选输入 初始状态标志,表示各序列是否使用缓存数据。 不支持此字段。 INT32 ND 1 ×
    bias(aclTensor*) 可选输入 卷积的偏置。 不支持此字段。 数据类型与x一致 ND 1 ×
    numAcceptedTokens(aclTensor*) 可选输入 公式中的numAcceptedTokens,表示每个batch的随机投机数。
    • 1<=numAcceptedTokens中的值<=seqlen,seqlen表示该batch的序列长度。
    • shape为[batch,]。
    INT32 ND 1 ×
    numComputedTokens(aclTensor*) 可选输入 公式中的numComputedTokens,当前batch已经处理的token总数,用于判断初始状态。
    • 首token时使用零初始化缓存;Pangu V2模式下不能为空。
    • shape为[batch,]。
    INT32 ND 1 ×
    blockIdxFirstScheduledToken(aclTensor*) 可选输入 当前 batch 的第一个token对应的block索引。
    • APC开启时不能为空。
    • shape为[batch,]。
    INT32 ND 1 ×
    blockIdxLastScheduledToken(aclTensor*) 可选输入 当前 batch 的最后一个token对应的block索引。
    • APC开启时不能为空。
    • shape为[batch,]。
    INT32 ND 1 ×
    initialStateIdx(aclTensor*) 可选输入 初始索引块的索引。
    • APC开启时不能为空。
    • shape为[batch,]。
    INT32 ND 1 ×
    activationMode(int64_t) 属性 激活函数类型。 不支持此字段。 INT64 - - -
    padSlotId(int64_t) 属性 用于跳过不需要参与计算的batch。
    • 仅支持不参与计算的变长序列在x的开头或结尾。
    • 当cacheIndices[i]==padSlotId时跳过该batch。
    INT64 - - -
    runMode(int64_t) 属性 用于判断是prefill场景或decode场景。 历史遗留接口,不支持此字段。 INT64 - - -
    maxQueryLen(int64_t) 属性 所有 batch 中的最大 seq_len,支持为-1。 - INT64 - - -
    residualConnection(int64_t) 属性 是否做残差连接。
    • 取值为0、1:
      0:不做残差连接;
      1:输出y和输入x相加后输出。
    INT64 - - -
    blockSize(int64_t) 属性 block块的大小。
    • 取值范围大于等于2,典型值128/256。
    INT64 - - -
    convMode(int64_t) 属性 公式中的convMode,支持Qwen3-Next和Pangu V2两种实现。
    • 取值为0、1:
      0:Qwen3-Next 社区版本实现;
      1:Pangu V2 实现。
    INT64 - - -
    y(aclTensor*) 输出 计算公式中的y,代表输出序列。 shape与x一致。 数据类型与x一致 ND 2-3 x
    workspaceSize(int64_t*) 输出 返回用户需要在Device侧申请的workspace大小。 - - - - -
    executor(aclOpExecutor**) 输出 返回op执行器,包含了算子计算流程。 - - - - -
  • 返回值

    aclnnStatus:返回状态码,具体参见aclnn返回码

    第一段接口完成入参校验,出现以下场景时报错:

    返回值 错误码 描述
    ACLNN_ERR_PARAM_NULLPTR 161001 传入的x、weight、convStates、y是空指针。
    ACLNN_ERR_INNER_TILING_ERROR 561002 输入和输出的数据类型不在支持的范围内。
    x、weight、convStates、bias、y的数据类型不一致。
    queryStartLoc、cacheIndices、initialStateMode、numAcceptedTokens、numComputedTokens、blockIdxFirstScheduledToken、blockIdxLastScheduledToken、initialStateIdx的数据类型不一致。
    输入、输出Tensor的shape不在支持的范围内。
    输入的属性不在支持的范围内。
    dim不在指定的取值范围内。

aclnnFusedCausalConv1d

  • 参数说明:

    参数名 输入/输出 描述
    workspace 输入 在Device侧申请的workspace内存地址。
    workspaceSize 输入 在Device侧申请的workspace大小,由第一段接口aclnnFusedCausalConv1dGetWorkspaceSize获取。
    executor 输入 op执行器,包含了算子计算流程。
    stream 输入 指定执行任务的Stream。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

约束说明

  • 确定性计算:

    • aclnnFusedCausalConv1d默认确定性实现。
  • 输入shape限制:

    • prefill场景:
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • cu_seq_len范围[batch, 1024 * 1024],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
    • prefill和decode混合场景:
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • cu_seq_len范围[batch, 1024 * 1024],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
    • decode场景(固定batch):
      • x支持3维[batch, m+1, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • m范围[0, 7],dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
    • decode场景(变长序列):
      • x支持2维[cu_seq_len, dim]。
      • weight必须是2维[K, dim],其中K固定为3。
      • conv_states必须是3维[..., k-1+m, dim],第0维大小不固定且大于等于batch。
      • cache_indices为1维[batch, ]或2维[batch, maxNumBlocks],其中1维表示未开启APC,2维表示开启APC。
      • cu_seq_len范围[batch, batch*8],每个batch的token个数范围为[1, 8]。dim范围[64, 16384]且是16的倍数,batch范围[1, 256],maxNumBlocks范围[1, 1024]。
  • 输入值域限制:

    • query_start_loc是累计偏移量,取值范围[0, cu_seq_len],长度为batch+1,query_start_loc[i]表示第i个序列的起始偏移,query_start_loc[batch+1]表示最后一个序列的结束位置。
    • blockSize 必须大于等于 2。
    • APC 开启时,必须提供 blockIdxFirstScheduledToken、blockIdxLastScheduledToken 及 initialStateIdx,且满足如下需求,i为batch的索引: - initialStateIdx[i] <= blockIdxFirstScheduledToken[i]+1 - initialStateIdx[i] <= blockIdxLastScheduledToken[i] - blockIdxFirstScheduledToken[i] <= blockIdxLastScheduledToken[i]
    • num_accepted_tokens分为None和非None,非None情况下长度为batch,每个元素取值不超过当前batch的token个数且大于0。
    • Pangu V2 模式(conv_mode = 1)下,首次运行 numComputedTokens 不能为 None。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例

#include <iostream>
#include <vector>
#include <cmath>
#include "acl/acl.h"
#include "aclnnop/aclnn_fused_causal_conv1d.h"

#define CHECK_RET(cond, return_expr)                                                                                   \
    do {                                                                                                               \
        if (!(cond)) {                                                                                                 \
            return_expr;                                                                                               \
        }                                                                                                              \
    } while (0)

#define LOG_PRINT(message, ...)                                                                                        \
    do {                                                                                                               \
        printf(message, ##__VA_ARGS__);                                                                                \
    } while (0)

int64_t GetShapeSize(const std::vector<int64_t> &shape)
{
    int64_t shapeSize = 1;
    for (auto i : shape) {
        shapeSize *= i;
    }
    return shapeSize;
}

int Init(int32_t deviceId, aclrtStream *stream)
{
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
}

// IEEE 754 fp16 decode (uint16_t → float)
float Fp16ToFloat(uint16_t raw)
{
    uint32_t sign = (raw >> 15) & 0x1;
    uint32_t exp = (raw >> 10) & 0x1F;
    uint32_t mant = raw & 0x3FF;
    float val;
    if (exp == 0) {
        // subnormal or zero
        val = (sign ? -1.0f : 1.0f) * (static_cast<float>(mant) / 1024.0f) * (1.0f / 16384.0f);
    } else if (exp == 31) {
        // infinity or NaN
        val = (mant == 0) ? (sign ? -HUGE_VALF : HUGE_VALF) : NAN;
    } else {
        val = (sign ? -1.0f : 1.0f) * (1.0f + static_cast<float>(mant) / 1024.0f) *
              ldexpf(1.0f, static_cast<int>(exp) - 15);
    }
    return val;
}

template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
                    aclDataType dataType, aclTensor **tensor, aclFormat format)
{
    auto size = GetShapeSize(shape) * sizeof(T);
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);

    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }

    *tensor = aclCreateTensor(shape.data(), static_cast<int64_t>(shape.size()), dataType, strides.data(), 0, format,
                              shape.data(), static_cast<int64_t>(shape.size()), *deviceAddr);
    return 0;
}

void FreeTensorAndBuffer(aclTensor *tensor, void *deviceAddr)
{
    if (tensor != nullptr) {
        aclDestroyTensor(tensor);
    }
    if (deviceAddr != nullptr) {
        aclrtFree(deviceAddr);
    }
}

int main()
{
    int32_t deviceId = 0;
    aclrtStream stream;
    auto ret = Init(deviceId, &stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

    int64_t K = 3;            // kernel width
    int64_t dim = 128;        // feature dimension
    int64_t batch = 4;        // number of sequences
    int64_t numSlots = 8;     // total cache slots (>= batch for non-APC)
    int64_t stateLen = K - 1; // cache state length per slot (= 2)
    // prefill: seq_lens = [5, 3, 7, 4], cuSeqLen = 19
    int64_t cuSeqLen = 19;

    // ---- Tensor shapes ----
    std::vector<int64_t> xShape = {cuSeqLen, dim};
    std::vector<int64_t> weightShape = {K, dim};
    std::vector<int64_t> convStatesShape = {numSlots, stateLen, dim};
    std::vector<int64_t> queryStartLocShape = {batch + 1};
    std::vector<int64_t> cacheIndicesShape = {batch}; // 1D: non-APC
    std::vector<int64_t> initialStateModeShape = {batch};
    std::vector<int64_t> biasShape = {dim};
    std::vector<int64_t> numAcceptedTokensShape = {batch};
    std::vector<int64_t> yShape = {cuSeqLen, dim};

    // ---- Host data ----
    std::vector<uint16_t> hostX(cuSeqLen * dim, 0x3C00); // fp16 1.0

    // weight: pass-through at k=0, zero elsewhere — 用于简化 golden 验证
    std::vector<uint16_t> hostWeight(K * dim, 0);
    for (int64_t d = 0; d < dim; d++) {
        hostWeight[0 * dim + d] = 0x3C00; // k=0: 1.0
    }

    // conv_states: zero-initialized, shape [8, 2, 128]
    std::vector<uint16_t> hostConvStates(numSlots * stateLen * dim, 0);

    // query_start_loc: [0, 5, 8, 15, 19]
    std::vector<int32_t> hostQueryStartLoc = {0, 5, 8, 15, 19};

    // cache_indices: each batch uses a distinct cache slot
    std::vector<int32_t> hostCacheIndices = {0, 3, 1, 5};

    // initial_state_mode: 1=from cache, 0=zero-init, 2=from cache (MTP variant)
    std::vector<int32_t> hostInitialStateMode = {1, 0, 2, 1};

    // bias: zero (unused)
    std::vector<uint16_t> hostBias(dim, 0);

    // num_accepted_tokens: 0 for prefill (no speculative decoding)
    std::vector<int32_t> hostNumAcceptedTokens(batch, 0);

    // y output buffer
    std::vector<uint16_t> hostY(cuSeqLen * dim, 0);

    // ---- Device pointers and aclTensors ----
    void *xDev = nullptr, *weightDev = nullptr, *convStatesDev = nullptr;
    void *queryStartLocDev = nullptr, *cacheIndicesDev = nullptr;
    void *initialStateModeDev = nullptr, *biasDev = nullptr;
    void *numAcceptedTokensDev = nullptr, *yDev = nullptr;

    aclTensor *xTensor = nullptr, *weightTensor = nullptr, *convStatesTensor = nullptr;
    aclTensor *queryStartLocTensor = nullptr, *cacheIndicesTensor = nullptr;
    aclTensor *initialStateModeTensor = nullptr, *biasTensor = nullptr;
    aclTensor *numAcceptedTokensTensor = nullptr, *yTensor = nullptr;

    // ---- Create required tensors ----
    ret = CreateAclTensor(hostX, xShape, &xDev, ACL_FLOAT16, &xTensor, ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = CreateAclTensor(hostWeight, weightShape, &weightDev, ACL_FLOAT16, &weightTensor, ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret =
        CreateAclTensor(hostConvStates, convStatesShape, &convStatesDev, ACL_FLOAT16, &convStatesTensor, ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = CreateAclTensor(hostQueryStartLoc, queryStartLocShape, &queryStartLocDev, ACL_INT32, &queryStartLocTensor,
                          ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = CreateAclTensor(hostCacheIndices, cacheIndicesShape, &cacheIndicesDev, ACL_INT32, &cacheIndicesTensor,
                          ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = CreateAclTensor(hostInitialStateMode, initialStateModeShape, &initialStateModeDev, ACL_INT32,
                          &initialStateModeTensor, ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = CreateAclTensor(hostBias, biasShape, &biasDev, ACL_FLOAT16, &biasTensor, ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = CreateAclTensor(hostNumAcceptedTokens, numAcceptedTokensShape, &numAcceptedTokensDev, ACL_INT32,
                          &numAcceptedTokensTensor, ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = CreateAclTensor(hostY, yShape, &yDev, ACL_FLOAT16, &yTensor, ACL_FORMAT_ND);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    int64_t activationMode = 0; // 0: None
    int64_t padSlotId = -1;     // -1: no pad-slot skipping
    int64_t runMode = 0;        // 0: prefill
    int64_t maxQueryLen = cuSeqLen;
    int64_t residualConnection = 0; // 0: no residual
    int64_t blockSize = 0;          // 0: non-APC
    int64_t convMode = 0;           // 0: Qwen/Pangu7B

    uint64_t workspaceSize = 0;
    aclOpExecutor *executor = nullptr;

    ret = aclnnFusedCausalConv1dGetWorkspaceSize(xTensor, weightTensor, convStatesTensor, queryStartLocTensor,
                                                 cacheIndicesTensor, initialStateModeTensor, biasTensor,
                                                 numAcceptedTokensTensor,
                                                 nullptr, // num_computed_tokens (non-APC: nullptr)
                                                 nullptr, // block_idx_first_scheduled_token
                                                 nullptr, // block_idx_last_scheduled_token
                                                 nullptr, // initial_state_idx
                                                 activationMode, padSlotId, runMode, maxQueryLen, residualConnection,
                                                 blockSize, convMode, yTensor, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnFusedCausalConv1dGetWorkspaceSize failed. ERROR: %d\n", ret);
              return ret);

    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
    }

    ret = aclnnFusedCausalConv1d(workspaceAddr, workspaceSize, executor, stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnFusedCausalConv1d failed. ERROR: %d\n", ret); return ret);

    ret = aclrtSynchronizeStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    auto ySize = GetShapeSize(yShape);
    std::vector<uint16_t> yResult(ySize, 0);
    ret = aclrtMemcpy(yResult.data(), ySize * sizeof(uint16_t), yDev, ySize * sizeof(uint16_t),
                      ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy y from device to host failed. ERROR: %d\n", ret); return ret);

    auto csSize = GetShapeSize(convStatesShape);
    std::vector<uint16_t> csResult(csSize, 0);
    ret = aclrtMemcpy(csResult.data(), csSize * sizeof(uint16_t), convStatesDev, csSize * sizeof(uint16_t),
                      ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy convStates from device to host failed. ERROR: %d\n", ret);
              return ret);

    LOG_PRINT("=== y output (first 10 elements) ===\n");
    for (int64_t i = 0; i < 10 && i < ySize; i++) {
        LOG_PRINT("  y[%lld] = %f\n", (long long)i, (double)Fp16ToFloat(yResult[i]));
    }
    LOG_PRINT("  ...\n");

    LOG_PRINT("=== convStates output (first 10 elements) ===\n");
    for (int64_t i = 0; i < 10 && i < csSize; i++) {
        LOG_PRINT("  convStates[%lld] = %f\n", (long long)i, (double)Fp16ToFloat(csResult[i]));
    }
    LOG_PRINT("  ...\n");

    FreeTensorAndBuffer(xTensor, xDev);
    FreeTensorAndBuffer(weightTensor, weightDev);
    FreeTensorAndBuffer(convStatesTensor, convStatesDev);
    FreeTensorAndBuffer(queryStartLocTensor, queryStartLocDev);
    FreeTensorAndBuffer(cacheIndicesTensor, cacheIndicesDev);
    FreeTensorAndBuffer(initialStateModeTensor, initialStateModeDev);
    FreeTensorAndBuffer(biasTensor, biasDev);
    FreeTensorAndBuffer(numAcceptedTokensTensor, numAcceptedTokensDev);
    FreeTensorAndBuffer(yTensor, yDev);

    if (workspaceSize > 0) {
        aclrtFree(workspaceAddr);
    }

    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    LOG_PRINT("Test completed successfully!\n");
    return 0;
}