sum_lstm Operator
功能说明
sum_lstm算子是Arctic-Inference框架中基于LSTM的speculator的核心机制。它具有以下特点:
- 输入融合: 将主要状态输入与额外输入信号进行加权融合
- RMS归一化: 使用RMSNorm代替LayerNorm,提高计算效率
- GELU激活: 使用近似GELU激活函数
- 门控机制: 标准的LSTM门控逻辑,控制信息的遗忘、输入和输出
- 状态更新: 通过遗忘门和输入门更新细胞状态
- 输出生成: 通过输出门控制最终的隐藏状态输出
该算子主要用于序列预测和生成任务中的speculative decoding,加速大语言模型的推理过程。
数学公式
Sum LSTM算子的计算过程如下:
-
输入融合:
fused = states_4d + alpha * z4_4d -
门控分割: 将fused按最后一维分割为4个部分,每个部分大小为
D_GATE:pre_f, pre_i, pre_o, pre_c = split(fused, 4, dim=-1) -
门控激活:
f_gate = sigmoid(pre_f) # 遗忘门 i_gate = sigmoid(pre_i) # 输入门 o_gate = sigmoid(pre_o) # 输出门 -
细胞候选处理:
c_cand_norm = RMSNorm(pre_c, eps_cell) c_cand_norm = c_cand_norm * w_cell + b_cell (如果权重存在) c_act = GELU(c_cand_norm) -
细胞状态更新:
c_new = prev_cell * f_gate + c_act * i_gate -
隐藏状态处理:
h_temp = RMSNorm(c_new, eps_state) h_temp = h_temp * w_state + b_state (如果权重存在) h_act = GELU(h_temp) -
最终输出:
h_new = h_act * o_gate
其中:
- RMSNorm(x, eps):
x * rsqrt(mean(x^2) + eps) - GELU(x):
x * sigmoid(1.702 * x)(近似实现)
函数原型
def sum_lstm_kernel(
states_4d: pypto.Tensor((BATCH_SIZE, D_GATE_4), pypto.DT_FP16),
z4_4d: pypto.Tensor((BATCH_SIZE, D_GATE_4), pypto.DT_FP16),
prev_cell: pypto.Tensor((BATCH_SIZE, D_GATE), pypto.DT_FP16),
w_cell: pypto.Tensor((D_GATE,), pypto.DT_FP16),
b_cell: pypto.Tensor((D_GATE,), pypto.DT_FP16),
w_state: pypto.Tensor((D_GATE,), pypto.DT_FP16),
b_state: pypto.Tensor((D_GATE,), pypto.DT_FP16),
config: LstmConfig,
h_out: pypto.Tensor((BATCH_SIZE, D_GATE), pypto.DT_FP16),
c_out: pypto.Tensor((BATCH_SIZE, D_GATE), pypto.DT_FP16)
) -> None
参数说明
- states_4d: 张量,形状为
(BATCH_SIZE, D_GATE_4),数据类型为 FP16。表示LSTM的状态输入,包含4个门控信号(遗忘门、输入门、输出门、细胞候选)。 - z4_4d: 张量,形状为
(BATCH_SIZE, D_GATE_4),数据类型为 FP16。表示额外的输入信号,与states_4d融合。 - prev_cell: 张量,形状为
(BATCH_SIZE, D_GATE),数据类型为 FP16。表示前一时刻的细胞状态。 - w_cell: 张量,形状为
(D_GATE,),数据类型为 FP16。细胞路径的权重参数。 - b_cell: 张量,形状为
(D_GATE,),数据类型为 FP16。细胞路径的偏置参数。 - w_state: 张量,形状为
(D_GATE,),数据类型为 FP16。状态路径的权重参数。 - b_state: 张量,形状为
(D_GATE,),数据类型为 FP16。状态路径的偏置参数。 - config: LstmConfig 对象,包含超参数:
alpha: 融合系数,默认 0.1eps_cell: 细胞RMSNorm的epsilon,默认 1e-6eps_state: 状态RMSNorm的epsilon,默认 1e-6
返回值说明
- h_out: 张量,形状为
(BATCH_SIZE, D_GATE),数据类型为 FP16。表示新的隐藏状态输出。 - c_out: 张量,形状为
(BATCH_SIZE, D_GATE),数据类型为 FP16。表示新的细胞状态输出。
调用示例
- 算子源码执行参考test_sum_lstm.py