文件最后提交记录最后更新时间
4 天前
4 天前
4 天前
4 天前
4 天前
21 小时前
21 小时前
4 天前
4 天前
README.md

CausalConv1d

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 x
Atlas A2 训练系列产品/Atlas A2 推理系列产品 x
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:完成因果一维卷积(Causal Conv1d)计算,支持前向计算(prefill / chunk-prefill)和状态更新(decode / update)两种运行模式,模式由输入形状自动推断。

  • 计算公式:

    Causal Conv1d 是一种因果一维卷积算子,常用于序列建模中。在每个时间步 tt,根据当前输入 xtx_t、卷积权重 ww 和历史状态,计算卷积输出 yty_t

    yt=Activation(∑j=0W−1wj⋅xt−j+b)y_t = \text{Activation}\left(\sum_{j=0}^{W-1} w_j \cdot x_{t-j} + b\right)

    其中,WW 为卷积核宽度(支持2、3、4),wjw_j 为卷积权重,bb 为偏置(可选),Activation\text{Activation} 为激活函数(可选,SiLU)。当 activation_mode="none" 时不使用激活函数,activation_mode="silu" 时使用 SiLU 激活函数。

    算子同时维护卷积状态 conv_states,用于在增量推理时缓存历史输入,实现高效的状态更新。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 输入序列,公式中的x。 BFLOAT16、FLOAT16 ND
weight 输入 卷积权重,公式中的w。 同x ND
conv_states 输入/输出
  • 卷积状态,缓存历史输入用于因果卷积计算。
  • 各序列计算完成后原地更新。
同x ND
bias 可选输入 偏置,公式中的b。若不提供则默认为0。 同x ND
query_start_loc 可选输入 序列起始位置索引,记录各序列在拼接张量x中的起始位置。queryStartLoc[0]必须为0,queryStartLoc[-1]必须为cu_seq_len。 INT32 ND
cache_indices 可选输入 缓存索引,指定每个序列对应的缓存状态在conv_states中的索引。不传时使用恒等映射。 INT32 ND
initial_state_mode 可选输入 初始状态标志。1:使用缓存的conv_states作为历史,0:零初始化历史(冷启动)。 INT32 ND
num_accepted_tokens 可选输入 每个序列接受的token数量,用于投机解码。 INT32 ND
activation_mode 可选属性
  • 激活函数类型。"silu":使用SiLU激活函数,"none":不使用激活函数。
  • 默认值为"silu"。
STR -
null_block_id 可选属性
  • 无效缓存槽位的标记ID,用于跳过不需要计算的序列。
  • 默认值为0。
INT64 -
y 输出 卷积输出,公式中的y。 同x ND

约束说明

调用说明

调用方式 样例代码 说明
aclnn接口 (prefill) test_aclnn_causal_conv1d_fn.cpp 通过 aclnnCausalConv1dFn 调用 prefill 模式
aclnn接口 (update) test_aclnn_causal_conv1d_update.cpp 通过 aclnnCausalConv1dUpdate 调用 update 模式