README.md

deepseek V3.2 样例 (Examples)

本目录包含了一系列 PyPTO deepseek V3.2 EXP 的开发样例代码,我们对 DeepSeek-V3.2-Exp 进行了拆解,交付了五个算子:mla prolog, lightning indexer prolog, sparese flash attention, mla_indexer_prolog和lightning indexer。

参数说明/约束

  • shape 格式字段含义说明
    字段名 英文全称/含义 取值规则与说明
    b Batch(输入样本批量大小) 取值范围:decode场景1~128 prefill场景固定为1
    s1 query Seq-Length 取值范围:decode场景 1~4 prefill场景1~1K
    s2 key Seq-Length 取值范围:1~128K
    h Head-Size(隐藏层大小) 取值固定为:7168
    n_q(n1) query 的 Head-Num(多头数) 取值范围:128
    n_kv(n2) kv 的 head 数 取值范围:1
    kv_lora_rank kv 低秩矩阵维度 取值范围:512
    rope_dim qk 位置编码维度 取值范围:64
    v_head_dim value 的头维度 取值范围:128
    q_head_dim query 的头维度 取值范围:192
    q_lora_rank query 低秩矩阵维度 取值范围:1536
    idx_n_heads indexer里query的head num 取值固定为:64
    idx_head_dim indexer里query的头维度 取值固定为:128
    selected_count topk选择的个数 取值固定为:2048
    block_num PagedAttention 场景下per-tile量的块数 取值为计算 B*Skv/BlockSize 的结果后向上取整(Skv 表示 kv 的序列长度,允许取 0)
    block_size PagedAttention 场景下的块大小 取值范围:128
    t BS 合轴后的大小 取值范围:b * s1

mla_polog_quant

功能说明

MLA Prolog 模块将hidden状态 X\bold{X} 转换为查询投影 q\bold{q}、键投影 k\bold{k} 和值投影 v\bold{v},其结构与 DeepSeek V3 的架构一致。在解码阶段,采用了权重吸收技术。

计算公式

RmsNorm公式

RmsNorm(x)=γ⋅xiRMS(x)\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)}

RMS(x)=1N∑i=1Nxi2+ϵ\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}

路径1:标准Query计算

包括下采样、RmsNorm和两次上采样:

cQ=RmsNorm(x⋅WDQ)c^Q = RmsNorm(x \cdot W^{DQ})

qC=cQ⋅WUQq^C = c^Q \cdot W^{UQ}

qN=qC⋅WUKq^N = q^C \cdot W^{UK}

路径2:位置编码Query计算

对Query进行ROPE旋转位置编码:

qR=ROPE(cQ⋅WQR)q^R = ROPE(c^Q \cdot W^{QR})

路径3:标准Key计算

包括下采样、RmsNorm,将计算结果存入Cache:

cKV=RmsNorm(x⋅WDKV)c^{KV} = RmsNorm(x \cdot W^{DKV})

kC=Cache(cKV)k^C = Cache(c^{KV})

路径4:位置编码Key计算

对Key进行ROPE旋转位置编码,并将结果存入Cache:

kR=Cache(ROPE(x⋅WKR))k^R = Cache(ROPE(x \cdot W^{KR}))

函数原型

def mla_prolog_quant_compute(token_x, w_dq, w_uq_qr, dequant_scale, w_uk, w_dkv_kr, gamma_cq, gamma_ckv, cos,
		sin, cache_index, kv_cache, kr_cache, k_scale_cache, q_norm_out, q_norm_scale_out, query_nope_out,
   		query_rope_out, kv_cache_out, kr_cache_out, k_scale_cache_out, epsilon_cq, epsilon_ckv, cache_mode,
    	tile_config, rope_cfg):

参数说明

  • token_xTensor):公式中用于计算Query和Key的输入tensor,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t, h]。
  • w_dqTensor):公式中用于计算Query的下采样权重矩阵WDQW^{DQ},不支持非连续的 Tensor。数据格式支持NZ,数据类型支持bfloat16,shape为[h, q_lora_rank]。
  • w_uq_qrTensor):公式中用于计算Query的上采样权重矩阵WUQW^{UQ}和位置编码权重矩阵WQRW^{QR},不支持非连续的 Tensor,数据格式支持NZ,数据类型支持int8,shape为[q_lora_rank, n_q * q_head_dim]。
  • dequant_scaleTensor):用于MatmulQcQr矩阵乘后w_uq_qr反量化操作的per-channel参数,不支持非连续的 Tensor。数据格式支持ND,数据类型支持float,shape为[n_q*q_head_dim, 1]。
  • w_ukTensor):公式中用于计算Key的上采样权重WUKW^{UK},不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[n_q, qk_nope_head_dim, kv_lora_rank]。
  • w_dkv_krTensor):公式中用于计算Key的下采样权重矩阵WDKVW^{DKV}和位置编码权重矩阵WKRW^{KR},不支持非连续的 Tensor,数据格式支持NZ,数据类型支持bfloat16,shape为[h, kv_lora_rank+rope_dim]。
  • gamma_cqTensor):计算cQc^Q的RmsNorm公式中的γ\gamma参数,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[q_lora_rank]。
  • gamma_ckvTensor):计算cKVc^{KV}的RmsNorm公式中的γ\gamma参数,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[kv_lora_rank]。
  • cosTensor):用于计算旋转位置编码的余弦参数矩阵,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t, rope_dim]。
  • sinTensor):用于计算旋转位置编码的正弦参数矩阵,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t, rope_dim]。
  • cache_indexTensor):用于存储kv_cache和kr_cache的索引,不支持非连续的 Tensor,数据格式支持ND,数据类型支持int64,shape为[t]。
  • kv_cacheTensor):用于cache索引的aclTensor,计算结果原地更新(对应公式中的kCk^C),不支持非连续的 Tensor。数据格式支持ND,数据类型支持int8,cache_mode为"PA_BSND",shape为[block_num, block_size, n_kv, kv_lora_rank]。
  • kr_cacheTensor):用于key位置编码的cache,计算结果原地更新(对应公式中的kRk^R),不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持bfloat16,cache_mode为"PA_BSND"、shape为[block_num, block_size, n_kv, rope_dim]。
  • k_scale_cacheTensor):表示 key 反量化因子的缓存,必选参数,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持float,shape为[block_num, block_size, n_kv, 4]。
  • epsilon_cqfloat):计算cQc^Q的RmsNorm公式中的ϵ\epsilon参数。用户未特意指定时,建议传入1e-05,仅支持double类型,默认值为1e-05。
  • epsilon_ckvfloat):计算cKVc^{KV}的RmsNorm公式中的ϵ\epsilon参数。用户未特意指定时,建议传入1e-05,仅支持double类型,默认值为1e-05。
  • cache_modestr):表示kv_cache的模式,支持"PA_BSND"。
  • tile_configclass MlaTileConfig):表示tile切分配置。
  • rope_cfgclass RopeTileShapeConfig):表示rope tile切分配置。

返回值说明

  • q_norm_outTensor):Query做RmsNorm_cq后的输出tensor(对应qCq^C),不支持非连续的 Tensor。数据格式支持ND,数据类型支持int8,shape为[t, q_lora_rank]。
  • q_norm_scale_outTensor):Query做RmsNorm_cq后的反量化参数,不支持非连续的 Tensor。数据格式支持ND,数据类型支持float,shape为[t, 1]。
  • q_nope_outTensor):公式中Query的输出tensor(对应qNq^N),不支持非连续的 Tensor。数据格式支持ND,数据类型支持bfloat16,shape为[t, n_q, kv_lora_rank]。
  • q_rope_outTensor):公式中Query位置编码的输出tensor(对应qRq^R),不支持非连续的 Tensor。数据格式支持ND,数据类型支持bfloat16,shape为[t, n_q, rope_dim]。
  • kv_cache_outTensor):Key输出到kv_cache中的tensor(对应kCk^C),不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持int8,shape为[block_num, block_size, n_kv, kv_lora_rank]。
  • kr_cache_outTensor):Key的位置编码输出到kr_cache中的tensor(对应kRk^R),不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持bfloat16,shape为[block_num, block_size, n_kv, qk_rope_dim]。
  • k_scale_cache_outTensor):Key做反量化后输出的反量化参数,不支持非连续的 Tensor。数据格式支持ND,数据类型支持float,cache_mode为"PA_BSND",shape为[block_num, block_size, n_kv, 4]。

调用示例

lightning_indexer_prolog

功能说明

用于 Deepseek IndexerAttention 中,计算 Lightning Indexer 所需要的 query,key 和 weights。 Indexer Prolog 的量化策略如下:Q_b_proj 使用 W8A8 量化,其他 Linear 均不量化;query 使用 A8 量化,key(cache) 使用 C8 量化;反量化因子以 FP16 存储;weights 以 FP16 存储;

计算公式

Query 的计算公式如下:

Q 的计算采用了动态的 Per-Token-Head 量化,其中 Hadamard 变换通过矩阵右乘 hadamard_q 实现。而 q,wqb\bold{q}, \bold{w}_{qb} 均是 Int8 类型。

q,qscale=DynamicQuant(Hadamard(RoPE(DeQuant(q⋅wqb))))\bold{q}, \bold{q}_{scale} = \text{DynamicQuant}(\text{Hadamard}(\text{RoPE}(\text{DeQuant}(\bold{q} \cdot \bold{w}_{qb}))))

Key(cache) 的计算公式如下:

Cache 的计算同样采用了动态的 Per-Token-Head 量化,其中 Hadamard 变换通过矩阵右乘 hadamard_k 实现。

k,kscale=DynamicQuant(Hadamard(RoPE(LayerNorm(x⋅wk))))\bold{k}, \bold{k}_{scale} = \text{DynamicQuant}(\text{Hadamard}(\text{RoPE}(\text{LayerNorm}(\bold{x} \cdot \bold{w}_k))))

Weights 的计算公式如下:

Weights 的计算没有采用量化,同时需要最后转化为 FP16 数据类型,供后续的 Lightning Indexer 计算使用。

weight=(x⋅wproj)∗scale\bold{weight} = (\bold{x} \cdot \bold{w}_{proj}) * \text{scale}

函数原型

def lightning_indexer_prolog_quant_compute(x_in, q_norm_in, q_norm_scale_in, w_qb_in, w_qb_scale_in, wk_in, w_proj_in,
				ln_gamma_k_in, ln_beta_k_in, cos_idx_rope_in, sin_idx_rope_in, hadamard_q_in, hadamard_k_in, k_int8_in, k_scale_in,
                k_cache_index_in, q_int8_out, q_scale_out, k_int8_out,k_scale_out, weights_out, attrs, configs):

参数说明

  • x_inTensor):表示 hidden 状态token_x,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t, h]。
  • q_norm_inTensor):表示经过 rmsnorm 后量化的 query,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持int8,shape为[t, q_lora_rank]。
  • q_norm_scale_inTensor):表示 query 的反量化因子,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持float32,shape为[t, 1]。
  • wq_b_inTensor):表示 query 的权重,必选参数,不支持非连续的Tensor,数据格式支持NZ,数据类型支持int8,shape为[q_lora_rank, idx_n_heads*idx_head_dim]。
  • wq_qb_scale_inTensor):表示 query 的权重反量化因子,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持float32,shape为[idx_n_heads*idx_head_dim, 1]。
  • wk_inTensor):表示 key 的权重,必选参数,不支持非连续的Tensor,数据格式支持NZ,数据类型支持bfloat16,shape为[h, idx_head_dim]。
  • w_proj_inTensor):表示 weights 的权重,必选参数,不支持非连续的Tensor,数据格式支持NZ,数据类型支持bfloat16,shape为[h, idx_n_heads]。
  • ln_gamma_k_inTensor):表示 key 的 layernorm 缩放,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[idx_head_dim]。
  • ln_beta_k_inTensor):表示 key 的 layernorm 偏移,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[idx_head_dim]。
  • cos_idx_rope_inTensor):表示用于 RoPE 的 cos,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持bfloat16,shape为[t, rope_dim]。
  • sin_idx_rope_inTensor):表示用于 RoPE 的 sin,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持bfloat16,shape为[t, rope_dim]。
  • hadamard_q_inTensor):表示用于 query Hadamard 变换的权重矩阵,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持bfloat16,shape为[idx_head_dim, idx_head_dim]。
  • hadamard_k_inTensor):表示用于 key Hadamard 变换的权重矩阵,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持bfloat16,shape为[idx_head_dim, idx_head_dim]。
  • k_int8_inTensor):表示 key 的缓存(k_cache),必选参数,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持int8,shape为[block_num, block_size, n_kv, idx_head_dim]。
  • k_scale_inTensor):表示 key 反量化因子的缓存,必选参数,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持float16,shape为[block_num, block_size, n_kv, 1]。
  • k_cache_index_inTensor):表示更新 key 缓存的位置,必选参数,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持int64,shape为[t]。
  • attrs.layernorm_epsilon_kfloat):表示 key layernorm 防除 0 系数,必选参数,数据类型支持float32
  • attrs.layout_querystr):可选参数,用于标识输入query的数据排布格式,默认值"TND"。当前仅支持 "TND"。
  • attrs.layout_keystr):可选参数,用于标识输入key的数据排布格式,默认值"PA_BSND"。当前仅支持 "PA_BSND"。
  • configsclass IndexerPrologQuantConfigs):表示tile切分配置。

返回值说明

  • q_int8_outTensor):公式中 query 的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持int8,shape为[t, idx_n_heads, idx_head_dim]。
  • q_scale_outTensor):公式中 query 反量化因子的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持float16,shape为[t, idx_n_heads, 1]。
  • k_int8_outTensor):表示 key 的缓存(k_cache)的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持int8,shape为[block_num, block_size, n_kv, idx_head_dim]。
  • k_scale_outTensor):表示 key 反量化因子的缓存的输出 tensor,不支持非连续的 Tensor,cache_mode为"PA_BSND",数据格式支持 ND,数据类型支持float16,shape为[block_num, block_size, n_kv, 1]。
  • weights_outTensor):公式中 weights 的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持float16,shape为[t, idx_n_heads]。

调用示例

sparse_flash_attention_quant

功能说明

对于每个查询 token xi\bold{x}_i,索引模块会为每个键值缓存项(表示键值对或 MLA 潜在表示)计算一个相关性得分 Ii,jI_{i,j}。然后,通过将注意力机制应用于查询 token xi\bold{x}_i 以及得分最高的前 kk 个缓存项,来计算输出 oi\bold{o}_i

计算公式

oi=Attn(xi,{cj∣j∈Top-k(Ii,:)})\bold{o}_i = \text{Attn}(\bold{x}_i, \{\bold{c}_j | j \in \text{Top-k}(\bold{I}_{i, :})\})

函数原型

def sparse_flash_attention_quant_compute(query_nope, query_rope, key_nope_2d, key_rope_2d, k_nope_scales,
		topk_indices, block_table, kv_act_seqs, attention_out, nq, n_kv, softmax_scale, topk, block_size,
        max_blocknum_perbatch, tile_config):

参数说明

  • query_nopeTensor):必选参数,表示MLA结构中的query的rope信息,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t * n_q, kv_lora_rank]。
  • query_ropeTensor):必选参数,表示MLA结构中的query的nope信息,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t * n_q, rope_dim]。
  • key_nope_2dTensor):必选参数,表示MLA结构中的key的rope信息,不支持非连续的 Tensor,数据格式支持ND,数据类型支持int8,shape为[block_num * block_size, kv_lora_rank]。
  • key_rope_2dTensor):必选参数,表示MLA结构中的key的nope信息,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[block_num * block_size, rope_dim]。
  • k_nope_scalesTensor):必选参数,表示k_nope的反量化缩放因子,必选参数,不支持非连续的 Tensor,数据格式支持ND,数据类型支持float,shape为[block_num * block_size, 4]。
  • topk_indicesTensor):必选参数,表示每个token选出的topk索引,必选参数,不支持非连续的 Tensor,数据格式支持ND,数据类型支持int32,shape为[t, n_kv * selected_count]。
  • block_tableTensor):必选参数,表示PageAttention中KV存储使用的block映射表,不支持非连续的 Tensor,数据格式支持ND,数据类型支持int32,shape为[b, s2_max/block_size],其中第二维表示长度不小于所有batch中最大的s2对应的block数量,即s2_max / block_size向上取整。
  • kv_act_seqsTensor):必选参数,数据格式支持ND,表示不同Batch中keyvalue的有效token数,数据类型支持int32,shape为[b]。
  • nqint):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持float。
  • n_kvint):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持float。
  • softmax_scalefloat):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持float。
  • topkint):必选参数,代表选取的token个数,数据类型支持int。
  • block_sizeint):必选参数,代表sparse阶段的block大小,数据类型支持int。
  • max_blocknum_perbatchint):必选参数,每个batch最大的blocksize数量,数据类型支持int。
  • tile_configclass SaTileShapeConfig):TileShapeConfig配置结构体,表示tile切分配置,配置项数据类型支持int。

返回值说明

  • attention_outTensor):公式中的输出。数据格式支持ND,数据类型支持bfloat16,输出shape[b, s1, n_q, kv_lora_rank]。

调用示例

sparse_attention_antiquant

功能说明

sa_antiquant是在sfa_quant基础上做的 存8算16 优化。在sfa_quant场景中,key_nope_2d,key_rope_2d 和 k_nope_scales 分别是 int8,bf16 和 fp32 类型;在后续 attention 的计算上,会离散地存储这三个 tensor,需要调三次离散访存指令去分别调用进行反量化和 concat;而 sa_antiquant 会将同一个 token 的 nope,rope 和 nope_scale 按尾轴合并在一起,仅需一条离散访存指令,总计可以节省 b * s * topk 次离散访存命令,节省搬运指令,提升搬运效率。

函数原型

def sparse_attention_antiquant_compute(query_nope, query_rope, nope_cache, topk_indices, block_table,
		kv_act_seqs, attention_out, nq, n_kv, softmax_scale, topk, block_size, max_blocknum_perbatch,
        tile_config):

参数说明

  • query_nopeTensor):必选参数,表示MLA结构中的query的rope信息,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t * n_q, kv_lora_rank]。
  • query_ropeTensor):必选参数,表示MLA结构中的query的nope信息,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t * n_q, rope_dim]。
  • nope_cacheTensor):必选参数,表示MLA结构中的key的反量化缩放因子,不支持非连续的 Tensor,数据格式支持ND,数据类型支持int8,shape为[block_num * block_size, kv_lora_rank + rope_dim * 2 + 4 * scale_size],其中scale_size=4。
  • topk_indicesTensor):必选参数,表示每个token选出的topk索引,必选参数,不支持非连续的 Tensor,数据格式支持ND,数据类型支持int32,shape为[t, n_kv * selected_count]。
  • block_tableTensor):必选参数,表示PageAttention中KV存储使用的block映射表,不支持非连续的 Tensor,数据格式支持ND,数据类型支持int32,shape为[b, s2_max/block_size],其中第二维表示长度不小于所有batch中最大的s2对应的block数量,即s2_max / block_size向上取整。
  • kv_act_seqsTensor):必选参数,数据格式支持ND,表示不同Batch中keyvalue的有效token数,数据类型支持int32,shape为[b]。
  • nqint):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持float。
  • n_kvint):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持float。
  • softmax_scalefloat):必选参数,代表缩放系数,作为query和key矩阵乘后Muls的scalar值,数据类型支持float。
  • topkint):必选参数,代表选取的token个数,数据类型支持int。
  • block_sizeint):必选参数,代表sparse阶段的block大小,数据类型支持int。
  • max_blocknum_perbatchint):必选参数,每个batch最大的blocksize数量,数据类型支持int。
  • tile_configclass SaTileShapeConfig):TileShapeConfig配置结构体,表示tile切分配置,配置项数据类型支持int。

返回值说明

  • attention_outTensor):公式中的输出。数据格式支持ND,数据类型支持bfloat16,输出shape[b * s1 * n_q, kv_lora_rank]。

调用示例

mla_indexer_polog_quant

功能说明

MLA Indexer Prolog 模块将MLA Prolog和Lightning Indexer Prolog两个算子进行了更大范围的融合,实现了算子间的流水并行,提升了算子的性能。

函数原型

def mla_indexer_prolog_quant_compute(
    token_x, mla_w_dq, mla_w_uq_qr, mla_dequant_scale, mla_w_uk, mla_w_dkv_kr, mla_gamma_cq,
    mla_gamma_ckv, cos, sin, cache_index, mla_kv_cache, mla_kr_cache,
    mla_k_scale_cache, ip_w_qb_in, ip_w_qb_scale_in, ip_wk_in, ip_w_proj_in,
    ip_ln_gamma_k_in, ip_ln_beta_k_in, ip_hadamard_q_in, ip_hadamard_k_in,
    ip_k_cache, ip_k_cache_scale, mla_query_nope_out, mla_query_rope_out,
    mla_q_norm_out, mla_q_norm_scale_out, mla_kv_cache_out, mla_kr_cache_out,
    mla_k_scale_cache_out, ip_q_int8_out, ip_q_scale_out, ip_k_int8_out,
    ip_k_scale_out, ip_weights_out, mla_epsilon_cq, mla_epsilon_ckv,
    mla_cache_mode, mla_tile_config,
    ip_attrs, ip_configs, rope_cfg
):

参数说明

  • token_xTensor):公式中用于计算Query和Key的输入tensor,不支持非连续的 Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[t, h]。
  • mla_w_dqTensor):公式中用于计算Query的下采样权重矩阵WDQW^{DQ},不支持非连续的 Tensor。数据格式支持NZ,数据类型支持bfloat16,shape为[h, q_lora_rank]。
  • mla_w_uq_qrTensor):公式中用于计算Query的上采样权重矩阵WUQW^{UQ}和位置编码权重矩阵WQRW^{QR}。不支持非连续,数据格式支持NZ,数据类型支持int8,shape为[q_lora_rank, n_q*q_head_dim]。
  • mla_dequant_scaleTensor):用于MatmulQcQr矩阵乘后w_uq_qr反量化操作的per-channel参数,不支持非连续的 Tensor。数据格式支持ND,数据类型支持float,shape为[n_q*q_head_dim, 1]。
  • mla_w_ukTensor):公式中用于计算Key的上采样权重WUKW^{UK}。不支持非连续,数据格式支持ND,数据类型支持bfloat16,shape为[n_q, qk_nope_head_dim, kv_lora_rank]。
  • mla_w_dkv_krTensor):公式中用于计算Key的下采样权重矩阵WDKVW^{DKV}和位置编码权重矩阵WKRW^{KR}。不支持非连续,数据格式支持NZ,数据类型支持bfloat16,shape为[h, kv_lora_rank+rope_dim]。
  • mla_gamma_cqTensor):计算cQc^Q的RmsNorm公式中的γ\gamma参数。不支持非连续,数据格式支持ND,数据类型支持bfloat16,shape为[q_lora_rank]。
  • mla_gamma_ckvTensor):计算cKVc^{KV}的RmsNorm公式中的γ\gamma参数。不支持非连续,数据格式支持ND,数据类型支持bfloat16,shape为[kv_lora_rank]。
  • cosTensor):用于计算旋转位置编码的余弦参数矩阵。不支持非连续,数据格式支持ND,数据类型支持bfloat16,shape为[t, rope_dim]。
  • sinTensor):用于计算旋转位置编码的正弦参数矩阵。不支持非连续,数据格式支持ND,数据类型支持bfloat16,shape为[t, rope_dim]。
  • cache_indexTensor):用于存储kv_cache和kr_cache的索引。不支持非连续,数据格式支持ND,数据类型支持int64,shape为[T]。
  • mla_kv_cacheTensor):用于cache索引的aclTensor,计算结果原地更新(对应公式中的kCk^C),不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持int8,cache_mode为"PA_BSND"、shape为[block_num, block_size, n_kv, kv_lora_rank]。
  • mla_kr_cacheTensor):用于key位置编码的cache,计算结果原地更新(对应公式中的kRk^R),不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持bfloat16,cache_mode为"PA_BSND"、shape为[block_num, block_size, n_kv, rope_dim]。
  • mla_k_scale_cacheTensor):表示 key 反量化因子的缓存,必选参数,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持float,shape为[block_num, block_size, n_kv, 4]。
  • ip_w_qb_inTensor):表示 query 的权重,必选参数,不支持非连续的Tensor,数据格式支持NZ,数据类型支持int8,shape为[q_lora_rank, idx_n_heads*idx_head_dim]。
  • ip_w_qb_scale_inTensor):表示 query 的权重反量化因子,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持float32,shape为[idx_n_heads*idx_head_dim, 1]。
  • ip_wk_inTensor):表示 key 的权重,必选参数,不支持非连续的Tensor,数据格式支持NZ,数据类型支持bfloat16,shape为[h, idx_head_dim]。
  • ip_w_proj_inTensor):表示 weights 的权重,必选参数,不支持非连续的Tensor,数据格式支持NZ,数据类型支持bfloat16,shape为[h, idx_n_heads]。
  • ip_ln_gamma_k_inTensor):表示 key 的 layernorm 缩放,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[idx_head_dim]。
  • ip_ln_beta_k_inTensor):表示 key 的 layernorm 偏移,必选参数,不支持非连续的Tensor,数据格式支持ND,数据类型支持bfloat16,shape为[idx_head_dim]。
  • ip_hadamard_q_inTensor):表示用于 query Hadamard 变换的权重矩阵,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持bfloat16,shape为[idx_head_dim, idx_head_dim]。
  • ip_hadamard_k_inTensor):表示用于 key Hadamard 变换的权重矩阵,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持bfloat16,shape为[idx_head_dim, idx_head_dim]。
  • ip_k_cacheTensor):表示 key 的缓存,必选参数,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持int8,cache_mode为"PA_BSND",shape为[block_num, block_size, n_kv, idx_head_dim]。
  • ip_k_cache_scaleTensor):表示 key 反量化因子的缓存,必选参数,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持float16,shape为[block_num, block_size, n_kv, 1]。
  • mla_epsilon_cqfloat):计算cQc^Q的RmsNorm公式中的ϵ\epsilon参数。用户未特意指定时,建议传入1e-05,仅支持double类型,默认值为1e-05。
  • mla_epsilon_ckvfloat):计算cKVc^{KV}的RmsNorm公式中的ϵ\epsilon参数。用户未特意指定时,建议传入1e-05,仅支持double类型,默认值为1e-05。
  • mla_cache_modestr):表示kv_cache的模式,支持"PA_BSND"
  • mla_tile_configclass MlaTileConfig):表示mla子图的tile切分配置。
  • ip_attrsclass IndexerPrologQuantAttr):lightning indexer prolog子图计算所需的属性值,包括layernorm_epsilon_k,layout_query,layout_key
  • ip.layernorm_epsilon_kfloat):表示 key layernorm 防除 0 系数,必选参数,数据类型支持float32
  • ip.layout_querystr):可选参数,用于标识输入query的数据排布格式,默认值"TND"。当前仅支持 "TND"。
  • ip.layout_keystr):可选参数,用于标识输入key的数据排布格式,默认值"PA_BSND"。当前仅支持 "PA_BSND"。
  • ip_configclass IndexerPrologQuantConfigs):表示ip子图的tile切分配置及动态分档配置。
  • rope_cfgclass RopeTileShapeConfig):表示rope子图的tile切分配置及动态分档配置。

返回值说明

  • mla_query_nope_outTensor):公式中Query的输出tensor(对应qNq^N),不支持非连续的 Tensor。数据格式支持ND,数据类型支持bfloat16,shape为[t, n_q, kv_lora_rank]。
  • mla_query_rope_outTensor):公式中Query位置编码的输出tensor(对应qRq^R),不支持非连续的 Tensor。数据格式支持ND,数据类型支持bfloat16,shape为[t, n_q, rope_dim]。
  • mla_q_norm_outTensor):对公式中Query位置编码的输出tensor做rmsnorm转换并量化后的输出,不支持非连续的 Tensor。数据格式支持ND,数据类型支持int8,shape为[t, q_lora_rank]。
  • mla_q_norm_scale_outTensor):对公式中Query位置编码的输出tensor做rmsnorm转换并量化后的反量化系数输出,不支持非连续的 Tensor。数据格式支持ND,数据类型支持float32,shape为[t, 1]。
  • mla_kv_cache_outTensor):Key输出到kv_cache中的tensor(对应kCk^C),不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持int8,shape为[block_num, block_size, n_kv, kv_lora_rank]。
  • mla_kr_cache_outTensor):Key的位置编码输出到kr_cache中的tensor(对应kRk^R),不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持bfloat16,shape为[block_num, block_size, n_kv, qk_rope_dim]。
  • mla_k_scale_cache_outTensor):Key做反量化后输出的反量化参数,不支持非连续的 Tensor。数据格式支持ND,cache_mode为"PA_BSND",数据类型支持float,shape为[block_num, block_size, n_kv, 4]。
  • ip_q_int8_outTensor):公式中 query 的输出 tensor,数据格式支持 ND,数据类型支持int8,shape为[t, idx_n_heads, idx_head_dim]。
  • ip_q_scale_outTensor):公式中 query 反量化因子的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持float16,shape为[t, idx_n_heads, 1]。
  • ip_k_int8_outTensor):表示 key 的缓存(k_cache)的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持int8,shape为[block_num, block_size, n_kv, idx_head_dim]。
  • ip_k_scale_outTensor):表示 key 反量化因子的缓存的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,cache_mode为"PA_BSND",数据类型支持float16,shape为[block_num, block_size, n_kv, 1]。
  • ip_weights_outTensor):公式中 weights 的输出 tensor,不支持非连续的 Tensor,数据格式支持 ND,数据类型支持float16,shape为[t, idx_n_heads]。

调用示例

lightning indexer

功能说明

LightningIndexer基于一系列操作得到每一个 token 对应的 Top-kk 个位置。对于某个 token 对应的 Index Query Qindex∈Rg×dQ_{index}\in\R^{g\times d},给定上下文 Index Key Kindex∈RSk×d,W∈Rg×1K_{index}\in\R^{S_{k}\times d},W\in\R^{g\times 1},其中 gg 为 GQA 对应的 group size,dd 为每一个头的维度,SkS_{k} 是上下文的长度,LightningIndexer的具体计算公式如下:

Top-k{[1]1×g@[(W@[1]1×Sk)⊙ReLU(Qindex@KindexT)]}\text{Top-}k\left\{[1]_{1\times g}@\left[(W@[1]_{1\times S_{k}})\odot\text{ReLU}\left(Q_{index}@K_{index}^T\right)\right]\right\}

函数原型

def lightning_indexer_decode_compute(
    idx_query, idx_query_scale, idx_key_cache, idx_key_scale, idx_weight, act_seq_key, block_table, topk_res,
    unroll_list, configs, selected_count):

参数说明

  • idx_queryTensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持int8,shape为[t, n_q, idx_head_dim]。
  • idx_query_scaleTensor):必选参数,表示idx_query的缩放系数,数据格式支持ND,数据类型支持float16,shape为[t, n_q, idx_head_dim]。
  • idx_key_cacheTensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持int8,shape为[t, n_kv, idx_head_dim]。
  • idx_key_scaleTensor):必选参数,表示idx_key_cache的缩放系数,数据格式支持ND,数据类型支持float16,shape为[t, n_kv, idx_head_dim]
  • idx_weightTensor):必选参数,不支持非连续,数据格式支持ND,数据类型支持float16,支持输入shape[t, n_q]。
  • act_seq_keyTensor):必选参数,表示不同Batch中key的有效token数,数据类型支持int32, shape为[b]。
  • block_tableTensor):必选参数,表示PageAttention中KV存储使用的block映射表,数据格式支持ND,数据类型支持int32,shape为[b, ceilDiv(max(s2), block_size)], 其中max(s2)为s2中最大值, ceilDiv表示向上取整。
  • unroll_listList):非必选参数,表示多档位切分配置。
  • configsclass LightningIndexerConfigs):非必选参数,LightningIndexerConfigs配置结构体,表示tile切分配置和优化选项。
  • selected_countint):必选参数,topk选择数量,默认为2048。

返回值说明

  • topk_resTensor):公式中的输出,数据类型支持int32。数据格式支持ND,输出shape[t, n_kv, selected_count]。

调用示例