"""
MLA Indexer Prolog Quantization Module
This module implements fused MLA Prolog and Lightning Indexer Prolog computation
for DeepSeek V32 model. It combines both operators to enable pipeline parallelism
and improve overall performance.
Main Functions:
- mla_indexer_prolog_quant_p: Fused computation for prefill phase
- mla_indexer_prolog_quant_d: Fused computation for decode phase
Example:
See deepseekv32_mla_indexer_prolog_quant.py for usage examples.
"""
import math
from typing import List
from dataclasses import dataclass
import torch
import torch_npu
import pypto
from lightning_indexer_prolog_quant_impl import rope_3d, quant_layer_norm, prolog_quant, quant_rope_2d
from mla_prolog_quant_impl import pre_compute_2d, rms_norm, rope_3d_v2, rope_v2, MlaQuantInputs, k_nope_quant
# Slot positions inside a 6-element cube tile configuration list, laid out as
# [L0 M, L1 M, L0 K, L1 K, L0 N, L1 N]. Callers pair them up as
# pypto.set_cube_tile_shapes([m0, m1], [k0, k1], [n0, n1]).
(
    L0M_INDEX,
    L1M_INDEX,
    L0K_INDEX,
    L1K_INDEX,
    L0N_INDEX,
    L1N_INDEX,
) = range(6)

# Axis passed to pypto.scatter_update when writing into the caches
# (the second-to-last dimension).
SCATTER_DIM = -2

# Upper bounds for the token (row) dimension of vector tile shapes.
VEC_TILE_4 = 4
VEC_TILE_32 = 32
def mla_indexer_prolog_quant_compute(
    token_x, mla_w_dq, mla_w_uq_qr, mla_dequant_scale, mla_w_uk, mla_w_dkv_kr, mla_gamma_cq,
    mla_gamma_ckv, cos, sin, cache_index, mla_kv_cache, mla_kr_cache,
    mla_k_scale_cache, ip_w_qb_in, ip_w_qb_scale_in, ip_wk_in, ip_w_proj_in,
    ip_ln_gamma_k_in, ip_ln_beta_k_in, ip_hadamard_q_in, ip_hadamard_k_in,
    ip_k_cache, ip_k_cache_scale, mla_query_nope_out, mla_query_rope_out,
    mla_q_norm_out, mla_q_norm_scale_out, mla_kv_cache_out, mla_kr_cache_out,
    mla_k_scale_cache_out, ip_q_int8_out, ip_q_scale_out, ip_k_int8_out,
    ip_k_scale_out, ip_weights_out, mla_epsilon_cq, mla_epsilon_ckv,
    mla_cache_mode, mla_tile_config,
    ip_attrs, ip_configs, rope_cfg
):
    """Build the fused MLA Prolog + Lightning Indexer Prolog graph.

    Shared body of ``mla_indexer_prolog_quant_p`` (prefill) and
    ``mla_indexer_prolog_quant_d`` (decode). The token axis of ``token_x`` is
    walked with ``pypto.loop_unroll`` using ``mla_tile_config.unroll_list``.
    Each iteration handles ``tile_bs`` tokens starting at ``bs_offset`` and:

    1. MLA Prolog: ``pre_compute_2d`` yields q, the compressed kv (with its
       RoPE part stored after the first ``kv_lora_rank`` channels), the
       quantised q_norm and its scale.
       - q_norm and its scale are written to ``mla_q_norm_out`` /
         ``mla_q_norm_scale_out`` and reused by the indexer below.
       - The nope part of q is batch-multiplied per head by ``mla_w_uk``.
       - RoPE is applied to q_pe and k_pe.
       - The compressed kv is RMS-normalised and quantised by
         ``k_nope_quant`` in 4 groups along its last axis.
       - k_rope, k_nope and the 4 group scales are scattered into the MLA
         caches at ``cache_index``.
    2. Indexer query: ``q_norm @ ip_w_qb_in`` accumulated in INT32.
       - It is dequantised by ``q_norm_scale`` and ``ip_w_qb_scale_in``.
       - RoPE is applied to the first ``qk_rope_head_dim`` channels of every
         head.
       - The result is multiplied by the query Hadamard matrix and quantised
         by ``prolog_quant``.
    3. Indexer key: ``token_x @ ip_wk_in``, then LayerNorm, then RoPE on the
       first ``qk_rope_head_dim`` channels, then the key Hadamard matrix,
       then ``prolog_quant``. The quantised key and its FP16 scale are
       scattered into the indexer key caches.
    4. Indexer weights: ``token_x @ ip_w_proj_in``, scaled by
       ``1 / (sqrt(head_num) * sqrt(head_dim))``, stored as FP16.

    NOTE(review): ``pypto.set_vec_tile_shapes`` / ``set_cube_tile_shapes``
    appear to configure tiling for the ops traced after them (inferred from
    usage). Treat statement order here as significant; some "repeated" tile
    settings may be restoring state changed inside the helper calls.

    Args:
        See ``mla_indexer_prolog_quant_p`` for per-argument descriptions.
        In addition:
        mla_q_norm_out: Intermediate tensor that receives MLA's quantised
            q_norm. It is read back as the indexer query input.
        mla_q_norm_scale_out: Intermediate tensor that receives the q_norm
            scale, shape ``(t, 1)`` in the callers.
        mla_cache_mode: Accepted for interface compatibility. It is not read
            by this function.

    Returns:
        None. All results are written into the ``*_out`` tensors.
    """
    dtype = token_x.dtype
    # Problem sizes are derived from input / weight shapes.
    h = token_x.shape[1]
    n1 = mla_w_uk.shape[0]
    q_lora_rank = mla_w_dq.shape[1]
    qk_nope_head_dim = mla_w_uk.shape[1]
    kv_lora_rank = mla_w_uk.shape[2]
    qk_rope_head_dim = sin.shape[1]
    head_num = ip_w_proj_in.shape[1]
    head_dim = ip_hadamard_q_in.shape[0]
    q_head_dim = qk_nope_head_dim + qk_rope_head_dim
    t = token_x.shape[0]
    quant_inputs = MlaQuantInputs()

    # Loop-invariant 2-D reshapes, done once outside the token loop.
    k_cache_index_2d = pypto.reshape(cache_index, [t, 1], inplace=True)
    w_qb_scale = pypto.reshape(ip_w_qb_scale_in, [1, head_num * head_dim], inplace=True)
    gamma_2d = pypto.reshape(ip_ln_gamma_k_in, [1, ip_ln_gamma_k_in.shape[0]], inplace=True)
    beta_2d = pypto.reshape(ip_ln_beta_k_in, [1, ip_ln_beta_k_in.shape[0]], inplace=True)
    if mla_dequant_scale is not None:
        dequant_scale_wuqr_reshape = pypto.reshape(mla_dequant_scale, [1, n1 * q_head_dim], inplace=True)
        quant_inputs.dequant_scale_w_uq_qr = dequant_scale_wuqr_reshape
    unroll_list = mla_tile_config.unroll_list
    for bs_offset, unroll_length in pypto.loop_unroll(0, t, 1, name="MLA_BS_LOOP", idx_name="bs_offset",
                                                      unroll_list=unroll_list, ):
        tile_bs = unroll_length
        output_offset = [bs_offset, 0, 0]

        # ------------------------------------------------------------------
        # MLA Prolog
        # ------------------------------------------------------------------
        pypto.set_vec_tile_shapes(tile_bs, 128)
        x_view = pypto.view(token_x, [tile_bs, h], [bs_offset, 0])
        q_kv = pre_compute_2d(x_view, mla_w_dq, mla_w_uq_qr, mla_w_dkv_kr, mla_gamma_cq,
                              mla_epsilon_cq, quant_inputs, mla_tile_config)
        q = q_kv[0]
        kv_tmp = q_kv[1]
        pypto.set_semantic_label("Assemble_qNorm")
        q_norm = q_kv[2]
        pypto.set_vec_tile_shapes(tile_bs, q_lora_rank)
        pypto.assemble(q_norm, [bs_offset, 0], mla_q_norm_out)
        q_norm_scale = q_kv[3]
        pypto.set_vec_tile_shapes(tile_bs, 1)
        pypto.assemble(q_norm_scale, [bs_offset, 0], mla_q_norm_scale_out)

        # q is [tile_bs, n1, nope + rope]. The nope part is multiplied per
        # head (batch axis n1 after transpose) by mla_w_uk.
        q_tmp = pypto.reshape(q, [tile_bs, n1, q_head_dim])
        pypto.set_semantic_label("Prepare_qNope")
        q_nope = pypto.view(q_tmp, [tile_bs, n1, qk_nope_head_dim], [0, 0, 0])
        tile_shape = [min(16, tile_bs), 32, qk_nope_head_dim]
        pypto.set_vec_tile_shapes(*tile_shape)
        q_nope_trans = pypto.transpose(q_nope, 0, 1)
        m = mla_tile_config.m_tile
        pypto.set_semantic_label("Matmul_qNope_wUk")
        pypto.set_cube_tile_shapes([m, m], [128, 128], [128, 128])
        q_nope_new = pypto.matmul(q_nope_trans, mla_w_uk, dtype)
        tile_shape = [1, min(32, tile_bs), kv_lora_rank]
        pypto.set_vec_tile_shapes(*tile_shape)
        q_nope_new_trans = pypto.transpose(q_nope_new, 0, 1)
        pypto.set_semantic_label("Assemble_queryOut")
        pypto.set_vec_tile_shapes(mla_tile_config.q_vec_tile0, mla_tile_config.q_vec_tile1, 128)
        pypto.assemble(q_nope_new_trans, output_offset, mla_query_nope_out)

        # RoPE on the trailing qk_rope_head_dim channels of each MLA head.
        if tile_bs >= 128:
            pypto.set_vec_tile_shapes(mla_tile_config.q_vec_tile0, mla_tile_config.q_vec_tile1, 64)
        q_pe_view = pypto.view(q_tmp, [tile_bs, n1, qk_rope_head_dim], [0, 0, qk_nope_head_dim])
        cos_2d_view = pypto.view(cos, [tile_bs, qk_rope_head_dim], [bs_offset, 0])
        sin_2d_view = pypto.view(sin, [tile_bs, qk_rope_head_dim], [bs_offset, 0])
        pypto.set_semantic_label("Rope_qRope")
        q_rope_view = rope_3d_v2(q_pe_view, cos_2d_view, sin_2d_view)
        pypto.set_semantic_label("Assemble_qRope")
        pypto.set_vec_tile_shapes(mla_tile_config.q_vec_tile0, mla_tile_config.q_vec_tile1, 64)
        pypto.assemble(q_rope_view, output_offset, mla_query_rope_out)

        # kv_tmp layout: [compressed kv (kv_lora_rank) | k_pe (qk_rope_head_dim)].
        pypto.set_vec_tile_shapes(mla_tile_config.k_vec_tile0, mla_tile_config.k_vec_tile1)
        pypto.set_semantic_label("RotaryPosEmb")
        k_pe_view = pypto.view(kv_tmp, [tile_bs, qk_rope_head_dim], [0, kv_lora_rank])
        k_rope_2d = rope_v2(k_pe_view, cos_2d_view, sin_2d_view, rope_cfg)
        compressed_kv = pypto.view(kv_tmp, [tile_bs, kv_lora_rank], [0, 0])
        pypto.set_semantic_label("RmsNorm_compressedkv")
        pypto.set_vec_tile_shapes(mla_tile_config.k_vec_tile0, mla_tile_config.k_vec_tile1)
        k_nope = rms_norm(compressed_kv, mla_gamma_ckv, mla_epsilon_ckv)

        # Quantise k_nope in 4 groups along the last axis (one scale per group).
        pypto.set_semantic_label("Quant_knope")
        pypto.set_vec_tile_shapes(32, kv_lora_rank)
        k_nope_split = pypto.reshape(k_nope, [tile_bs, 4, kv_lora_rank // 4])
        pypto.set_vec_tile_shapes(32, 4, kv_lora_rank // 4)
        k_nope_quant_res = k_nope_quant(k_nope_split)
        k_nope_quant_tensor = k_nope_quant_res[0]
        k_nope_scale = k_nope_quant_res[1]
        # Re-applied after k_nope_quant; kept in case the helper changes tiling.
        pypto.set_vec_tile_shapes(32, 4, kv_lora_rank // 4)
        k_nope_2d = pypto.reshape(k_nope_quant_tensor, [tile_bs, kv_lora_rank])
        k_scale_2d = pypto.reshape(k_nope_scale, [tile_bs, 4])
        k_rope_4d = pypto.reshape(k_rope_2d, [tile_bs, 1, 1, qk_rope_head_dim], inplace=True)
        k_nope_4d = pypto.reshape(k_nope_2d, [tile_bs, 1, 1, kv_lora_rank], inplace=True)
        k_scale_4d = pypto.reshape(k_scale_2d, [tile_bs, 1, 1, 4], inplace=True)

        # Cache slots for this chunk of tokens. Shared by the MLA and the
        # indexer cache scatters.
        index = pypto.view(k_cache_index_2d, [tile_bs, 1], [bs_offset, 0])
        pypto.set_semantic_label("ScatterUpdate_krCache")
        pypto.set_vec_tile_shapes(32, 1, 1, qk_rope_head_dim)
        mla_kr_cache_out[:] = pypto.scatter_update(mla_kr_cache, SCATTER_DIM, index, k_rope_4d)
        pypto.set_semantic_label("ScatterUpdate_kvCache")
        pypto.set_vec_tile_shapes(32, 1, 1, kv_lora_rank)
        mla_kv_cache_out[:] = pypto.scatter_update(mla_kv_cache, SCATTER_DIM, index, k_nope_4d)
        pypto.set_semantic_label("ScatterUpdate_kScaleCache")
        pypto.set_vec_tile_shapes(32, 1, 1, 4)
        mla_k_scale_cache_out[:] = pypto.scatter_update(mla_k_scale_cache, SCATTER_DIM, index, k_scale_4d)

        # ------------------------------------------------------------------
        # Indexer Prolog: query (consumes MLA's quantised q_norm and its scale)
        # ------------------------------------------------------------------
        q_linear = ip_configs.q_linear
        q_hd = ip_configs.q_hd
        pypto.set_semantic_label("Query-Linear")
        pypto.set_cube_tile_shapes([q_linear[L0M_INDEX], q_linear[L1M_INDEX]],
                                   [q_linear[L0K_INDEX], q_linear[L1K_INDEX]],
                                   [q_linear[L0N_INDEX], q_linear[L1N_INDEX]])
        q_s32 = pypto.matmul(q_norm, ip_w_qb_in, pypto.DT_INT32)
        pypto.set_semantic_label("Query-Dequant")
        pypto.set_vec_tile_shapes(ip_configs.t_sub_tile, head_num * head_dim // ip_configs.chunk_size)
        q_f32 = pypto.cast(q_s32, pypto.DT_FP32)
        q_f32 = q_f32 * q_norm_scale
        q_f32 = q_f32 * w_qb_scale
        q_cast = pypto.cast(q_f32, dtype)
        q_bf16 = pypto.reshape(q_cast, [tile_bs, head_num, head_dim], valid_shape=[tile_bs, head_num, head_dim])
        # Unlike the MLA heads, the indexer head stores its RoPE channels first.
        idx_q_rope = pypto.view(q_bf16, [tile_bs, head_num, qk_rope_head_dim], [0, 0, 0],
                                valid_shape=[tile_bs, head_num, qk_rope_head_dim])
        idx_q_nope = pypto.view(q_bf16, [tile_bs, head_num, head_dim - qk_rope_head_dim], [0, 0, qk_rope_head_dim],
                                valid_shape=[tile_bs, head_num, head_dim - qk_rope_head_dim])
        q_roped = rope_3d(idx_q_rope, cos_2d_view, sin_2d_view, ip_configs)
        pypto.set_vec_tile_shapes(ip_configs.t_sub_tile, head_num // ip_configs.chunk_size, head_dim)
        # Round-trip through FP32 yields a new tensor of the same dtype for concat.
        idx_q_nope = pypto.cast(pypto.cast(idx_q_nope, pypto.DT_FP32), q_bf16.dtype)
        q_cat = pypto.concat([q_roped, idx_q_nope], -1)
        hadamard_q = pypto.reshape(ip_hadamard_q_in, [1, head_dim, head_dim], valid_shape=[1, head_dim, head_dim])
        pypto.set_semantic_label("Query-Hadamard")
        # Small chunks pin the M tiles to 32. NOTE(review): q_hd[L1M_INDEX] is
        # never read; both M tiles take q_hd[L0M_INDEX] otherwise.
        cur_max_unroll = 32
        q_hd_m_tile = cur_max_unroll if tile_bs < cur_max_unroll else q_hd[L0M_INDEX]
        pypto.set_cube_tile_shapes([q_hd_m_tile, q_hd_m_tile], [q_hd[L0K_INDEX], q_hd[L1K_INDEX]],
                                   [q_hd[L0N_INDEX], q_hd[L1N_INDEX]])
        q_hadamard = pypto.matmul(q_cat, hadamard_q, dtype)
        pypto.set_semantic_label("Query-Quant")
        pypto.set_vec_tile_shapes(ip_configs.t_sub_tile, head_num // ip_configs.chunk_size, head_dim)
        q_res = prolog_quant(q_hadamard)
        q_scale = pypto.cast(q_res[1], pypto.DT_FP16)
        pypto.assemble(q_res[0], [bs_offset, 0, 0], ip_q_int8_out)
        pypto.assemble(q_scale, [bs_offset, 0, 0], ip_q_scale_out)

        # ------------------------------------------------------------------
        # Indexer Prolog: key
        # ------------------------------------------------------------------
        k_linear = ip_configs.k_linear
        pypto.set_semantic_label("Key-Linear")
        pypto.set_cube_tile_shapes([k_linear[L0M_INDEX], k_linear[L1M_INDEX]],
                                   [k_linear[L0K_INDEX], k_linear[L1K_INDEX]],
                                   [k_linear[L0N_INDEX], k_linear[L1N_INDEX]])
        k = pypto.matmul(x_view, ip_wk_in, pypto.DT_FP32)
        if tile_bs <= 32:
            pypto.set_vec_tile_shapes(min(tile_bs, VEC_TILE_4), head_dim)
        else:
            pypto.set_vec_tile_shapes(min(tile_bs, VEC_TILE_32), head_dim)
        k_bf16 = pypto.cast(quant_layer_norm(k, gamma_2d, beta_2d, -1, ip_attrs.eps), dtype)
        idx_k_rope = pypto.view(k_bf16, [tile_bs, qk_rope_head_dim], [0, 0],
                                valid_shape=[tile_bs, qk_rope_head_dim])
        idx_k_nope = pypto.view(k_bf16, [tile_bs, head_dim - qk_rope_head_dim], [0, qk_rope_head_dim],
                                valid_shape=[tile_bs, head_dim - qk_rope_head_dim])
        k_roped = quant_rope_2d(idx_k_rope, cos_2d_view, sin_2d_view)
        pypto.set_vec_tile_shapes(tile_bs, head_dim)
        idx_k_nope = pypto.cast(pypto.cast(idx_k_nope, pypto.DT_FP32), k_bf16.dtype)
        k_concat = pypto.concat([k_roped, idx_k_nope], -1)
        pypto.set_semantic_label("Key-Hadamard")
        # No cube tile shapes are set between Key-Linear and this matmul in
        # this function.
        hadamard_k = pypto.matmul(k_concat, ip_hadamard_k_in, dtype)
        pypto.set_semantic_label("Key-Quant")
        k_res = prolog_quant(hadamard_k)
        k_cache_4d = pypto.reshape(k_res[0], [tile_bs, 1, 1, head_dim], valid_shape=[tile_bs, 1, 1, head_dim])
        idx_k_scale_4d = pypto.reshape(pypto.cast(k_res[1], pypto.DT_FP16), [tile_bs, 1, 1, 1],
                                       valid_shape=[tile_bs, 1, 1, 1])
        pypto.set_vec_tile_shapes(tile_bs, 1, 1, head_dim)
        ip_k_int8_out.move(pypto.scatter_update(ip_k_cache, SCATTER_DIM, index, k_cache_4d))
        ip_k_scale_out.move(pypto.scatter_update(ip_k_cache_scale, SCATTER_DIM, index, idx_k_scale_4d))

        # ------------------------------------------------------------------
        # Indexer Prolog: per-head weights
        # ------------------------------------------------------------------
        pypto.set_semantic_label("Weight-Linear")
        w_linear = ip_configs.w_linear
        pypto.set_cube_tile_shapes([w_linear[L0M_INDEX], w_linear[L1M_INDEX]],
                                   [w_linear[L0K_INDEX], w_linear[L1K_INDEX]],
                                   [w_linear[L0N_INDEX], w_linear[L1N_INDEX]])
        pypto.set_vec_tile_shapes(tile_bs, head_num)
        weights = pypto.cast(pypto.matmul(x_view, ip_w_proj_in, dtype), pypto.DT_FP32)
        weights = pypto.mul(weights, 1.0 / (math.sqrt(head_num) * math.sqrt(head_dim)))
        weights_f16 = pypto.cast(weights, pypto.DT_FP16)
        pypto.assemble(weights_f16, [bs_offset, 0], ip_weights_out)
@pypto.frontend.jit(
    pass_options={
        "cube_l1_reuse_setting": {-1: 4},
    },
    runtime_options={"stitch_function_max_num": 128,
                     "device_sched_mode": 2}
)
def mla_indexer_prolog_quant_p(
    token_x: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    mla_w_dq: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    mla_w_uq_qr: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT8, format=pypto.TileOpFormat.TILEOP_NZ),
    mla_dequant_scale: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    mla_w_uk: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_w_dkv_kr: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    mla_gamma_cq: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    mla_gamma_ckv: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    cos: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    sin: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    cache_index: pypto.Tensor([pypto.DYNAMIC], pypto.DT_INT64),
    mla_kv_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    mla_kr_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_k_scale_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    ip_w_qb_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT8, format=pypto.TileOpFormat.TILEOP_NZ),
    ip_w_qb_scale_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    ip_wk_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    ip_w_proj_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    ip_ln_gamma_k_in: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    ip_ln_beta_k_in: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    ip_hadamard_q_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    ip_hadamard_k_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    ip_k_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    ip_k_cache_scale: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
    mla_query_nope_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_query_rope_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_kv_cache_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    mla_kr_cache_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_k_scale_cache_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    ip_q_int8_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    ip_q_scale_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
    ip_k_int8_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    ip_k_scale_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
    ip_weights_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_FP16),
    mla_epsilon_cq,
    mla_epsilon_ckv,
    mla_cache_mode,
    mla_tile_config,
    ip_attrs,
    ip_configs,
    rope_cfg
):
    """Fused MLA and Indexer Prolog quantization for the prefill phase.

    Combines the MLA Prolog and Lightning Indexer Prolog computations in a
    single fused operator for the prefill phase. The goal is to enable
    pipeline parallelism and reduce memory transfers between the two
    operators.

    Computation flow:
        1. MLA Prolog: computes the MLA query, key and value projections.
        2. Indexer Prolog: reuses MLA's quantised q_norm output to compute
           the indexer query, key and weights.

    Compared with ``mla_indexer_prolog_quant_d``, this entry point sets
    ``stitch_function_max_num=128`` and a single ``cube_l1_reuse_setting``
    entry, and does not set ``cube_nbuffer_setting``.

    Args:
        token_x: Input token tensor, shape (t, h), BF16.
        mla_w_dq: MLA down-projection weight for query, BF16, NZ format.
        mla_w_uq_qr: MLA up-projection weight for query and RoPE, INT8, NZ format.
        mla_dequant_scale: MLA dequantization scale for ``mla_w_uq_qr``, FP32.
        mla_w_uk: MLA up-projection weight for key, shape (n1, qk_nope_head_dim,
            kv_lora_rank), BF16.
        mla_w_dkv_kr: MLA down-projection weight for key-value and RoPE, BF16, NZ format.
        mla_gamma_cq: MLA RMSNorm scale for query, BF16.
        mla_gamma_ckv: MLA RMSNorm scale for key-value, BF16.
        cos: Cosine values for RoPE, shape (t, qk_rope_head_dim), BF16.
        sin: Sine values for RoPE, shape (t, qk_rope_head_dim), BF16.
        cache_index: Per-token cache slot used for every scatter update, INT64.
        mla_kv_cache: MLA key-value cache input, INT8.
        mla_kr_cache: MLA key RoPE cache input, BF16.
        mla_k_scale_cache: MLA key quantization scale cache input, FP32.
        ip_w_qb_in: Indexer query projection weight, INT8, NZ format.
        ip_w_qb_scale_in: Indexer query weight dequantization scale, FP32.
        ip_wk_in: Indexer key projection weight, BF16, NZ format.
        ip_w_proj_in: Indexer weight projection matrix, BF16, NZ format.
        ip_ln_gamma_k_in: Indexer LayerNorm scale for key, BF16.
        ip_ln_beta_k_in: Indexer LayerNorm shift for key, BF16.
        ip_hadamard_q_in: Indexer Hadamard matrix for query, BF16.
        ip_hadamard_k_in: Indexer Hadamard matrix for key, BF16.
        ip_k_cache: Indexer key cache input, INT8.
        ip_k_cache_scale: Indexer key cache scale input, FP16.
        mla_query_nope_out: Output MLA query (nope part projected by
            ``mla_w_uk``), BF16.
        mla_query_rope_out: Output MLA query with RoPE applied, BF16.
        mla_kv_cache_out: Output MLA key-value cache, INT8.
        mla_kr_cache_out: Output MLA key RoPE cache, BF16.
        mla_k_scale_cache_out: Output MLA key scale cache, FP32.
        ip_q_int8_out: Output indexer quantized query, INT8.
        ip_q_scale_out: Output indexer query quantization scale, FP16.
        ip_k_int8_out: Output indexer key cache, INT8.
        ip_k_scale_out: Output indexer key cache scale, FP16.
        ip_weights_out: Output indexer weights, FP16.
        mla_epsilon_cq: MLA RMSNorm epsilon for query.
        mla_epsilon_ckv: MLA RMSNorm epsilon for key-value.
        mla_cache_mode: MLA cache mode. Accepted for interface compatibility;
            currently not read by the computation.
        mla_tile_config: MLA tiling config (unroll_list, m_tile,
            q_vec_tile0/1, k_vec_tile0/1).
        ip_attrs: Indexer attributes object (provides ``eps``).
        ip_configs: Indexer tiling configuration object (q_linear, q_hd,
            k_linear, w_linear, t_sub_tile, chunk_size).
        rope_cfg: RoPE tile shape configuration passed to ``rope_v2``.

    Note:
        Two intermediate tensors are created to pass data from the MLA Prolog
        to the Indexer Prolog: ``mla_q_norm_out`` (INT8, shape
        ``(t, ip_w_qb_in.shape[0])``) and ``mla_q_norm_scale_out`` (FP32,
        shape ``(t, 1)``). ``device_sched_mode=2`` is set in the runtime
        options. According to the original design, this enables pipeline
        parallelism.
    """
    t = token_x.shape[0]
    actual_q_lora_rank = ip_w_qb_in.shape[0]
    mla_q_norm_out = pypto.Tensor((t, actual_q_lora_rank), pypto.DT_INT8)
    mla_q_norm_scale_out = pypto.Tensor((t, 1), pypto.DT_FP32)
    mla_indexer_prolog_quant_compute(
        token_x, mla_w_dq, mla_w_uq_qr, mla_dequant_scale, mla_w_uk, mla_w_dkv_kr, mla_gamma_cq,
        mla_gamma_ckv, cos, sin, cache_index, mla_kv_cache, mla_kr_cache,
        mla_k_scale_cache, ip_w_qb_in, ip_w_qb_scale_in, ip_wk_in, ip_w_proj_in,
        ip_ln_gamma_k_in, ip_ln_beta_k_in, ip_hadamard_q_in, ip_hadamard_k_in,
        ip_k_cache, ip_k_cache_scale, mla_query_nope_out, mla_query_rope_out,
        mla_q_norm_out, mla_q_norm_scale_out,
        mla_kv_cache_out, mla_kr_cache_out,
        mla_k_scale_cache_out, ip_q_int8_out, ip_q_scale_out, ip_k_int8_out,
        ip_k_scale_out, ip_weights_out, mla_epsilon_cq, mla_epsilon_ckv,
        mla_cache_mode, mla_tile_config,
        ip_attrs, ip_configs, rope_cfg
    )
@pypto.frontend.jit(
    pass_options={
        "cube_l1_reuse_setting": {-1: 4, 0: 1, 1: 1, 2: 1},
        "cube_nbuffer_setting": {-1: 6, 0: 1, 1: 1, 2: 1},
    },
    runtime_options={"device_sched_mode": 2}
)
def mla_indexer_prolog_quant_d(
    token_x: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    mla_w_dq: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    mla_w_uq_qr: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT8, format=pypto.TileOpFormat.TILEOP_NZ),
    mla_dequant_scale: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    mla_w_uk: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_w_dkv_kr: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    mla_gamma_cq: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    mla_gamma_ckv: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    cos: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    sin: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    cache_index: pypto.Tensor([pypto.DYNAMIC], pypto.DT_INT64),
    mla_kv_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    mla_kr_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_k_scale_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    ip_w_qb_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT8, format=pypto.TileOpFormat.TILEOP_NZ),
    ip_w_qb_scale_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    ip_wk_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    ip_w_proj_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
    ip_ln_gamma_k_in: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    ip_ln_beta_k_in: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
    ip_hadamard_q_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    ip_hadamard_k_in: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    ip_k_cache: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    ip_k_cache_scale: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
    mla_query_nope_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_query_rope_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_kv_cache_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    mla_kr_cache_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    mla_k_scale_cache_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    ip_q_int8_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    ip_q_scale_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
    ip_k_int8_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_INT8),
    ip_k_scale_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
    ip_weights_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_FP16),
    mla_epsilon_cq,
    mla_epsilon_ckv,
    mla_cache_mode,
    mla_tile_config,
    ip_attrs,
    ip_configs,
    rope_cfg
):
    """Fused MLA and Indexer Prolog quantization for the decode phase.

    Combines the MLA Prolog and Lightning Indexer Prolog computations in a
    single fused operator for the decode phase. The goal is to enable
    pipeline parallelism and reduce memory transfers between the two
    operators.

    Computation flow:
        1. MLA Prolog: computes the MLA query, key and value projections.
        2. Indexer Prolog: reuses MLA's quantised q_norm output to compute
           the indexer query, key and weights.

    Compared with ``mla_indexer_prolog_quant_p``, this entry point adds
    per-index overrides to ``cube_l1_reuse_setting`` and
    ``cube_nbuffer_setting``, and does not set ``stitch_function_max_num``.

    Args:
        token_x: Input token tensor, shape (t, h), BF16.
        mla_w_dq: MLA down-projection weight for query, BF16, NZ format.
        mla_w_uq_qr: MLA up-projection weight for query and RoPE, INT8, NZ format.
        mla_dequant_scale: MLA dequantization scale for ``mla_w_uq_qr``, FP32.
        mla_w_uk: MLA up-projection weight for key, shape (n1, qk_nope_head_dim,
            kv_lora_rank), BF16.
        mla_w_dkv_kr: MLA down-projection weight for key-value and RoPE, BF16, NZ format.
        mla_gamma_cq: MLA RMSNorm scale for query, BF16.
        mla_gamma_ckv: MLA RMSNorm scale for key-value, BF16.
        cos: Cosine values for RoPE, shape (t, qk_rope_head_dim), BF16.
        sin: Sine values for RoPE, shape (t, qk_rope_head_dim), BF16.
        cache_index: Per-token cache slot used for every scatter update, INT64.
        mla_kv_cache: MLA key-value cache input, INT8.
        mla_kr_cache: MLA key RoPE cache input, BF16.
        mla_k_scale_cache: MLA key quantization scale cache input, FP32.
        ip_w_qb_in: Indexer query projection weight, INT8, NZ format.
        ip_w_qb_scale_in: Indexer query weight dequantization scale, FP32.
        ip_wk_in: Indexer key projection weight, BF16, NZ format.
        ip_w_proj_in: Indexer weight projection matrix, BF16, NZ format.
        ip_ln_gamma_k_in: Indexer LayerNorm scale for key, BF16.
        ip_ln_beta_k_in: Indexer LayerNorm shift for key, BF16.
        ip_hadamard_q_in: Indexer Hadamard matrix for query, BF16.
        ip_hadamard_k_in: Indexer Hadamard matrix for key, BF16.
        ip_k_cache: Indexer key cache input, INT8.
        ip_k_cache_scale: Indexer key cache scale input, FP16.
        mla_query_nope_out: Output MLA query (nope part projected by
            ``mla_w_uk``), BF16.
        mla_query_rope_out: Output MLA query with RoPE applied, BF16.
        mla_kv_cache_out: Output MLA key-value cache, INT8.
        mla_kr_cache_out: Output MLA key RoPE cache, BF16.
        mla_k_scale_cache_out: Output MLA key scale cache, FP32.
        ip_q_int8_out: Output indexer quantized query, INT8.
        ip_q_scale_out: Output indexer query quantization scale, FP16.
        ip_k_int8_out: Output indexer key cache, INT8.
        ip_k_scale_out: Output indexer key cache scale, FP16.
        ip_weights_out: Output indexer weights, FP16.
        mla_epsilon_cq: MLA RMSNorm epsilon for query.
        mla_epsilon_ckv: MLA RMSNorm epsilon for key-value.
        mla_cache_mode: MLA cache mode. Accepted for interface compatibility;
            currently not read by the computation.
        mla_tile_config: MLA tiling config (unroll_list, m_tile,
            q_vec_tile0/1, k_vec_tile0/1).
        ip_attrs: Indexer attributes object (provides ``eps``).
        ip_configs: Indexer tiling configuration object (q_linear, q_hd,
            k_linear, w_linear, t_sub_tile, chunk_size).
        rope_cfg: RoPE tile shape configuration passed to ``rope_v2``.

    Note:
        Two intermediate tensors are created to pass data from the MLA Prolog
        to the Indexer Prolog: ``mla_q_norm_out`` (INT8, shape
        ``(t, ip_w_qb_in.shape[0])``) and ``mla_q_norm_scale_out`` (FP32,
        shape ``(t, 1)``). ``device_sched_mode=2`` is set in the runtime
        options. According to the original design, this enables pipeline
        parallelism.
    """
    t = token_x.shape[0]
    actual_q_lora_rank = ip_w_qb_in.shape[0]
    mla_q_norm_out = pypto.Tensor((t, actual_q_lora_rank), pypto.DT_INT8)
    mla_q_norm_scale_out = pypto.Tensor((t, 1), pypto.DT_FP32)
    mla_indexer_prolog_quant_compute(
        token_x, mla_w_dq, mla_w_uq_qr, mla_dequant_scale, mla_w_uk, mla_w_dkv_kr, mla_gamma_cq,
        mla_gamma_ckv, cos, sin, cache_index, mla_kv_cache, mla_kr_cache,
        mla_k_scale_cache, ip_w_qb_in, ip_w_qb_scale_in, ip_wk_in, ip_w_proj_in,
        ip_ln_gamma_k_in, ip_ln_beta_k_in, ip_hadamard_q_in, ip_hadamard_k_in,
        ip_k_cache, ip_k_cache_scale, mla_query_nope_out, mla_query_rope_out,
        mla_q_norm_out, mla_q_norm_scale_out,
        mla_kv_cache_out, mla_kr_cache_out,
        mla_k_scale_cache_out, ip_q_int8_out, ip_q_scale_out, ip_k_int8_out,
        ip_k_scale_out, ip_weights_out, mla_epsilon_cq, mla_epsilon_ckv,
        mla_cache_mode, mla_tile_config,
        ip_attrs, ip_configs, rope_cfg
    )