"""
"""
from dataclasses import dataclass
import math
from typing import List, Tuple
import torch
from torch._dynamo import allow_in_graph
from torch._subclasses.fake_tensor import FakeTensor
import torch_npu
import pypto
"""
MLA Prolog Quantization Module
This module implements MLA (Multi-head Latent Attention) Prolog quantization
for DeepSeek V4 model. It converts hidden states to query, key, and value
projections with support for quantization and RoPE (Rotary Position Embedding).
Main Functions:
- mla_prolog_pypto: Public entry point; calls the registered torch op
- mla_prolog_v4: JIT-compiled kernel that wraps mla_prolog_v4_compute
- mla_prolog_v4_compute: Core MLA prolog computation (linear, RMS norm, RoPE)
- check_input_output_shape_dtype: Shape and dtype validation of all tensors
- rms_norm: RMS normalization implementation
- rotate_half: Half-rotation helper used by RoPE
- rope_2d: 2D RoPE implementation
- rope_3d: 3D RoPE implementation
Example:
See test_mla_prolog_quant_v4.py for usage examples.
"""
# Tensor ranks checked by the RoPE helpers' assertions.
SHAPE_DIM_2 = 2
SHAPE_DIM_3 = 3

# Plain integer literals (not referenced by the code visible in this module).
NUM_0, NUM_1, NUM_2, NUM_3 = range(4)
NUM_4096 = 4096
NUM_512 = 512

# Dimension counts (not referenced by the code visible in this module).
TILE_CUBE_DIM = 6
Q_PARAM_DIM = 2
NZ_DIM = 4
COS_SIN_DIM = 2

# Six consecutive slot indices; presumably positions inside a cube tile
# description of TILE_CUBE_DIM entries -- confirm against the consumer.
L0M_INDEX, L1M_INDEX, L0K_INDEX, L1K_INDEX, L0N_INDEX, L1N_INDEX = range(6)

SCATTER_DIM = -2

# NZ-layout related sizes (meaning inferred from the names only -- confirm).
NZ_FIRST_DIM = 16
NZ_B8_C0 = 32
NZ_B16_C0 = 16

# Vector tile extents.
VEC_TILE_256 = 256
VEC_TILE_128 = 128
VEC_TILE_64 = 64
VEC_TILE_32 = 32
VEC_TILE_8 = 8
VEC_TILE_4 = 4
@dataclass
class MlaPrologV4Input:
    """Input operands of the MLA prolog.

    The field names mirror the first eight parameters of
    ``mla_prolog_v4_compute``. This container used to be declared under the
    name ``MlaPrologV4Output`` and was therefore silently replaced by the
    result container below; it now has its own name.
    """

    x: torch.Tensor
    wq_a: torch.Tensor
    wq_b: torch.Tensor
    wkv: torch.Tensor
    rmsnorm_gamma_cq: torch.Tensor
    rmsnorm_gamma_ckv: torch.Tensor
    cos: torch.Tensor
    sin: torch.Tensor


@dataclass
class MlaPrologV4Output:
    """Result tensors of the MLA prolog: query, key/value and normalized qr."""

    # ``torch.Tensor`` is the tensor type; ``torch.tensor`` is only a factory
    # function and is not a valid annotation.
    q: torch.Tensor
    kv: torch.Tensor
    qr: torch.Tensor
@dataclass
class MlaPrologV4Attrs:
    """Scalar attributes handed to the MLA prolog kernel."""

    # Epsilon forwarded to rms_norm by mla_prolog_v4_compute.
    eps: float
    # Layout tags; the host wrapper sets them to "TND" / "PA_BSND". They are
    # not read by the compute code visible in this module.
    layout_query: str
    layout_key: str
@dataclass
class MlaPrologV4Configs:
    """Tuning knobs handed to the MLA prolog kernel.

    Only ``unroll_list`` is read by the compute code visible in this module;
    the remaining fields are carried along unchanged.
    """

    # Candidate tile lengths passed to pypto.loop_unroll for the token loop.
    unroll_list: List[int]
    cube_l1_reuse_setting: dict[int, int]
    mg_copyin_upper_bound: int
    pg_upper_bound: int
    block_size: int
    t_sub_tile: int
    chunk_size: int
def check_input_output_shape_dtype(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv,
                                   output_q_data, output_kv_data, output_qr_data):
    """Validate rank, fixed axis sizes and dtype of every kernel operand.

    All shape checks run first (in parameter order), then all dtype checks.
    The rank of a tensor is always verified before any of its axes is
    queried, so a tensor of the wrong rank yields the validation error rather
    than an ``IndexError`` from ``size()``.

    Args:
        token_x: Hidden states, rank 2 with axis1 == 4096.
        wq_a: Rank 2, shape (4096, 1024).
        wq_b: Rank 2, shape (1024, 32768).
        wkv: Rank 2, shape (4096, 512).
        rope_cos: Rank 2 with axis1 == 64.
        rope_sin: Rank 2 with axis1 == 64.
        gamma_cq: Rank 1 with 1024 elements.
        gamma_ckv: Rank 1 with 512 elements.
        output_q_data: Rank 3 with axis1 == 64 and axis2 == 512.
        output_kv_data: Rank 2 with axis1 == 512.
        output_qr_data: Rank 2 with axis1 == 1024.

    Raises:
        AssertionError: If any tensor has an unexpected rank, axis size or a
            dtype other than ``torch.bfloat16``. Raised explicitly so the
            validation is not stripped when Python runs with ``-O``.
    """
    # (name, tensor, expected rank, {axis: expected size})
    specs = (
        ("token_x", token_x, 2, {1: 4096}),
        ("wq_a", wq_a, 2, {0: 4096, 1: 1024}),
        ("wq_b", wq_b, 2, {0: 1024, 1: 32768}),
        ("wkv", wkv, 2, {0: 4096, 1: 512}),
        ("rope_cos", rope_cos, 2, {1: 64}),
        ("rope_sin", rope_sin, 2, {1: 64}),
        ("gamma_cq", gamma_cq, 1, {0: 1024}),
        ("gamma_ckv", gamma_ckv, 1, {0: 512}),
        ("output_q_data", output_q_data, 3, {1: 64, 2: 512}),
        ("output_kv_data", output_kv_data, 2, {1: 512}),
        ("output_qr_data", output_qr_data, 2, {1: 1024}),
    )

    for name, tensor, rank, axis_sizes in specs:
        if tensor.dim() != rank:
            raise AssertionError(f"expected {name} dim num {rank}, got {tensor.dim()}")
        for axis, size in axis_sizes.items():
            if tensor.size(axis) != size:
                raise AssertionError(f"expected {name} axis{axis} {size}, got {tensor.size(axis)}")

    for name, tensor, _rank, _axis_sizes in specs:
        if tensor.dtype != torch.bfloat16:
            raise AssertionError(f"{name}.dtype is {tensor.dtype}, expected torch.bfloat16")
def rms_norm(input_tensor: pypto.Tensor, epsilon: float) -> pypto.Tensor:
    """Scale a tensor by the reciprocal root mean square of its last axis.

    Formula: output = input / sqrt(mean(input^2) + epsilon)

    No weight (gamma) is applied here; callers that need one multiply the
    result themselves.

    Args:
        input_tensor: Tensor to normalize along its last dimension.
        epsilon: Constant added to the mean of squares before the square root.

    Returns:
        The normalized tensor in FP32 (it is not cast back to the input dtype).
    """
    x_fp32 = pypto.cast(input_tensor, pypto.DT_FP32)
    last_axis = len(input_tensor.shape) - 1

    # mean(x^2): every square is pre-scaled by 1/N, then summed over the axis.
    squares = pypto.mul(x_fp32, x_fp32)
    scaled_squares = pypto.mul(squares, 1.0 / input_tensor.shape[last_axis])
    mean_square = pypto.sum(scaled_squares, -1, keepdim=True)

    rms = pypto.sqrt(pypto.add(mean_square, epsilon))

    # Reciprocal expressed as ones / rms, then broadcast-multiplied into x.
    ones = pypto.full(rms.shape, 1.0, pypto.DT_FP32)
    inv_rms = pypto.div(ones, rms)
    return pypto.mul(x_fp32, inv_rms)
def rotate_half(input_tensor: pypto.Tensor) -> pypto.Tensor:
    """Return ``concat([-x2, x1])`` along the last axis.

    ``x1`` is the first half of the last dimension and ``x2`` the second half
    (the split point is ``last_dim // 2``). This is the rotation step used by
    the RoPE helpers in this module.

    Args:
        input_tensor: Tensor whose last dimension is split into two halves.
            No validation is performed; an even last dimension is expected.

    Returns:
        The concatenation of the negated second half and the first half.
    """
    rank = len(input_tensor.shape)
    last = rank - 1

    half_shape = list(input_tensor.shape)
    half_shape[last] //= 2

    # Views of the two halves: same shape, different start along the last axis.
    front_offset = [0] * rank
    back_offset = [0] * rank
    back_offset[last] = half_shape[last]
    front = pypto.view(input_tensor, half_shape, front_offset)
    back = pypto.view(input_tensor, half_shape, back_offset)

    # "+ 0.0" is kept from the original implementation; presumably it turns
    # the view into a computed tensor before the concat -- confirm.
    return pypto.concat([back * (-1.0), front + 0.0], -1)
def rope_2d(
        x: pypto.Tensor, cos: pypto.Tensor, sin: pypto.Tensor) -> pypto.Tensor:
    """Apply rotary position embedding to a rank-2 tensor.

    The last axis is treated as interleaved pairs: it is de-interleaved
    (reshape to ``[rows, width // 2, 2]`` + transpose), passed through
    ``rotate_half``, and interleaved again before being combined as
    ``x * cos + rotated * sin`` in FP32.

    Args:
        x: Rank-2 input tensor.
        cos: Rank-2 cosine table, multiplied element-wise with ``x``.
        sin: Rank-2 sine table, multiplied element-wise with the rotated input.

    Returns:
        The embedded tensor, cast back to the dtype of ``x``.

    Raises:
        AssertionError: If any of the three tensors is not rank 2.
    """
    assert all(len(t.shape) == 2 for t in (x, cos, sin))
    out_dtype = x.dtype
    half_width = x.shape[1] // 2

    pypto.set_vec_tile_shapes(1, 64, 128)
    x_copy = pypto.clone(x)
    x_fp32 = pypto.cast(x_copy, pypto.DT_FP32)

    # De-interleave: [rows, w/2, 2] -> [rows, 2, w/2] -> [rows, w]
    pairs = pypto.reshape(x, [x.shape[0], half_width, 2])
    pairs_t = pypto.transpose(pairs, 1, 2)
    deinterleaved = pypto.reshape(pairs_t, x.shape)

    rotated = rotate_half(deinterleaved)

    # Re-interleave: [rows, 2, w/2] -> [rows, w/2, 2] -> [rows, w]
    halves = pypto.reshape(rotated, [x.shape[0], 2, half_width])
    halves_t = pypto.transpose(halves, 1, 2)
    rotated_interleaved = pypto.reshape(halves_t, x.shape)
    rotated_fp32 = pypto.cast(rotated_interleaved, pypto.DT_FP32)

    pypto.set_vec_tile_shapes(1, 64)
    cos_fp32 = pypto.cast(cos, pypto.DT_FP32)
    sin_fp32 = pypto.cast(sin, pypto.DT_FP32)

    pypto.set_vec_tile_shapes(1, 64, 64)
    embedded = x_fp32 * cos_fp32 + rotated_fp32 * sin_fp32
    return pypto.cast(embedded, out_dtype)
def rope_3d(x: pypto.Tensor, cos: pypto.Tensor, sin: pypto.Tensor) -> pypto.Tensor:
    """Apply rotary position embedding to a rank-3 tensor.

    Computes ``x * cos + rotated(x) * sin`` in FP32, where ``rotated`` treats
    the last axis as interleaved pairs: the axis is de-interleaved, passed
    through ``rotate_half`` and interleaved again. ``cos`` and ``sin`` are
    reshaped to ``[x.shape[0], 1, x.shape[2]]`` so they broadcast over axis 1.

    Args:
        x: Rank-3 input tensor ``[t, h, d]``.
        cos: Rank-2 cosine table, reshaped to ``[t, 1, d]``.
        sin: Rank-2 sine table, reshaped to ``[t, 1, d]``.

    Returns:
        The embedded tensor, cast back to the dtype of ``x``.

    Raises:
        AssertionError: If ``x`` is not rank 3 or ``cos``/``sin`` are not rank 2.
    """
    assert (len(x.shape) == SHAPE_DIM_3 and len(cos.shape) == SHAPE_DIM_2 and len(sin.shape) == SHAPE_DIM_2)

    pypto.set_vec_tile_shapes(1, 64)
    cast_cos = pypto.cast(cos, pypto.DataType.DT_FP32)
    cast_sin = pypto.cast(sin, pypto.DataType.DT_FP32)

    pypto.set_vec_tile_shapes(1, 64, 64)
    cast_x = pypto.cast(x, pypto.DataType.DT_FP32)
    cast_cos = pypto.reshape(cast_cos, [x.shape[0], 1, x.shape[2]])
    cast_sin = pypto.reshape(cast_sin, [x.shape[0], 1, x.shape[2]])

    # De-interleave the last axis: [t, h, d/2, 2] -> [t, h, 2, d/2] -> [t, h, d]
    x_view = pypto.reshape(cast_x, [x.shape[0], x.shape[1], x.shape[2] // 2, 2])
    pypto.set_vec_tile_shapes(1, 64, 64, 64)
    x_trans = pypto.transpose(x_view, 2, 3)
    x_re_second = pypto.reshape(x_trans, x.shape)

    pypto.set_vec_tile_shapes(1, 64, 64)
    x_rotate = rotate_half(x_re_second)

    # Re-interleave, working on the transposed layout:
    # [t, d, h] -> [t, 2, d/2, h] -> [t, d/2, 2, h] -> [t, d, h] -> [t, h, d]
    x_rotate_trs_1 = pypto.transpose(x_rotate, 1, 2)
    x_rotate_reshape_1 = pypto.reshape(x_rotate_trs_1, [
        x_rotate_trs_1.shape[0], 2, x_rotate_trs_1.shape[1] // 2, x_rotate_trs_1.shape[2]])
    pypto.set_vec_tile_shapes(1, 64, 64, 64)
    x_rotate_trs_2 = pypto.transpose(x_rotate_reshape_1, 1, 2)
    # The data is laid out as [t, d, h] here, so collapse to the transposed
    # shape. Using x_rotate.shape ([t, h, d]) only works when h == d.
    x_rotate_reshape_2 = pypto.reshape(x_rotate_trs_2, x_rotate_trs_1.shape)
    pypto.set_vec_tile_shapes(1, 64, 64)
    x_rotate_res = pypto.transpose(x_rotate_reshape_2, 1, 2)

    pypto.set_vec_tile_shapes(1, 64, 64)
    x_embed = cast_x * cast_cos + x_rotate_res * cast_sin
    x_embed_cast = pypto.cast(x_embed, x.dtype)
    return x_embed_cast
def mla_prolog_v4_compute(x, wq_a, wq_b, wkv, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, cos, sin, q_out, kv_out, qr_out,
                          attrs, configs):
    """Build the MLA prolog computation and write into the three outputs.

    The token axis of ``x`` is processed in tiles produced by
    ``pypto.loop_unroll``. For every tile:

    1. ``qr = rms_norm(x_tile @ wq_a) * gamma_cq`` is written to ``qr_out``.
    2. ``q = rms_norm(reshape(qr @ wq_b, [t_tile, head_num, head_dim]))``; the
       last ``rope_dim`` entries of every head go through ``rope_3d`` and the
       result is written to ``q_out``.
    3. ``kv = rms_norm(x_tile @ wkv) * gamma_ckv``; the last ``rope_dim``
       entries go through ``rope_2d`` and the result is written to ``kv_out``.

    Args:
        x: Hidden states, indexed as ``[token, hidden]``.
        wq_a: Weight of the first query projection.
        wq_b: Weight of the second query projection.
        wkv: Weight of the key/value projection.
        rmsnorm_gamma_cq: 1-D scale applied after the first query RMS norm.
        rmsnorm_gamma_ckv: 1-D scale applied after the key/value RMS norm;
            its length is used as ``head_dim``.
        cos: RoPE cosine table, indexed as ``[token, rope_dim]``.
        sin: RoPE sine table, indexed as ``[token, rope_dim]``.
        q_out: Destination for the query, ``[token, head_num, head_dim]``.
        kv_out: Destination for the key/value, ``[token, head_dim]``.
        qr_out: Destination for the normalized low-rank query.
        attrs: Object providing ``eps`` for the RMS norms.
        configs: Object providing ``unroll_list`` for the token loop.
    """
    token_num = x.shape[0]
    hidden = x.shape[1]
    q_lora_rank = rmsnorm_gamma_cq.shape[0]
    head_dim = rmsnorm_gamma_ckv.shape[0]
    head_num = wq_b.shape[1] // head_dim
    rope_dim = cos.shape[1]
    nope_dim = head_dim - rope_dim

    # Gammas as [1, N] FP32 rows so they broadcast over the token axis.
    gamma_cq_row = pypto.reshape(rmsnorm_gamma_cq, [1, q_lora_rank], inplace=True)
    gamma_ckv_row = pypto.reshape(rmsnorm_gamma_ckv, [1, head_dim], inplace=True)
    pypto.set_vec_tile_shapes(4, q_lora_rank)
    gamma_cq_fp32 = pypto.cast(gamma_cq_row, pypto.DataType.DT_FP32)
    gamma_ckv_fp32 = pypto.cast(gamma_ckv_row, pypto.DataType.DT_FP32)

    for t_idx, t_tile in pypto.loop_unroll(0, token_num, 1, name="MLA_BS_LOOP", idx_name="bs_offset",
                                           unroll_list=configs.unroll_list):
        x_tile = pypto.view(x, [t_tile, hidden], [t_idx, 0], valid_shape=[t_tile, hidden])

        # ---- first query projection + weighted RMS norm -> qr_out ----
        pypto.set_semantic_label("wqa-linear")
        pypto.set_cube_tile_shapes([32, 32], [512, 512], [64, 64])
        q_a = pypto.matmul(x_tile, wq_a, pypto.DataType.DT_BF16)

        pypto.set_semantic_label("q-rmsnorm with weight")
        pypto.set_vec_tile_shapes(8, q_lora_rank)
        q_norm = rms_norm(q_a, attrs.eps)
        q_norm = pypto.mul(q_norm, gamma_cq_fp32)
        q_norm = pypto.cast(q_norm, pypto.DataType.DT_BF16)
        pypto.assemble(q_norm, [t_idx, 0], qr_out)

        # ---- second query projection, per-head RMS norm, RoPE -> q_out ----
        pypto.set_semantic_label("wqb-linear")
        pypto.set_cube_tile_shapes([32, 32], [128, 128], [256, 256])
        q_b = pypto.matmul(q_norm, wq_b, pypto.DataType.DT_BF16)
        q_heads = pypto.reshape(q_b, [t_tile, head_num, head_dim])
        pypto.set_vec_tile_shapes(4, 64, 64)
        q_heads_norm = rms_norm(q_heads, attrs.eps)
        q_heads_norm = pypto.cast(q_heads_norm, pypto.DataType.DT_BF16)

        pypto.set_vec_tile_shapes(4, 64)
        cos_tile = pypto.view(cos, [t_tile, rope_dim], [t_idx, 0], valid_shape=[t_tile, rope_dim])
        sin_tile = pypto.view(sin, [t_tile, rope_dim], [t_idx, 0], valid_shape=[t_tile, rope_dim])

        pypto.set_vec_tile_shapes(4, 64, 64)
        q_nope = pypto.view(q_heads_norm, [t_tile, head_num, nope_dim], [0, 0, 0],
                            valid_shape=[t_tile, head_num, nope_dim])
        q_rope = pypto.view(q_heads_norm, [t_tile, head_num, rope_dim], [0, 0, nope_dim],
                            valid_shape=[t_tile, head_num, rope_dim])
        q_rope = rope_3d(q_rope, cos_tile, sin_tile)
        q_full = pypto.concat([q_nope, q_rope], -1)
        pypto.assemble(q_full, [t_idx, 0, 0], q_out)

        # ---- key/value projection, weighted RMS norm, RoPE -> kv_out ----
        pypto.set_semantic_label("wkv-linear")
        pypto.set_cube_tile_shapes([32, 32], [256, 256], [128, 128])
        kv_proj = pypto.matmul(x_tile, wkv, pypto.DataType.DT_BF16)
        pypto.set_vec_tile_shapes(4, 64)
        kv_normed = rms_norm(kv_proj, attrs.eps)
        kv_normed = pypto.mul(kv_normed, gamma_ckv_fp32)
        kv_normed = pypto.cast(kv_normed, pypto.DataType.DT_BF16)
        kv_nope = pypto.view(kv_normed, [t_tile, nope_dim], [0, 0], valid_shape=[t_tile, nope_dim])
        kv_rope = pypto.view(kv_normed, [t_tile, rope_dim], [0, nope_dim], valid_shape=[t_tile, rope_dim])
        kv_rope = rope_2d(kv_rope, cos_tile, sin_tile)
        kv_full = pypto.concat([kv_nope, kv_rope], -1)
        pypto.assemble(kv_full, [t_idx, 0], kv_out)
class MLAKernelMAnager:
    """Shape table handed to the JIT as its ``infer_controlflow_shape`` hook.

    For every token-count bucket in ``t_vec`` the table stores the shapes of
    all eleven kernel tensors, in kernel parameter order: x, wq_a, wq_b, wkv,
    rmsnorm_gamma_cq, rmsnorm_gamma_ckv, cos, sin, q_out, kv_out, qr_out.
    """

    def __init__(self):
        self.vec_all_shape = {}
        # Buckets must stay in descending order: the lookup returns the first
        # bucket that is <= the actual token count.
        self.t_vec = [4096, 128, 64, 32, 16, 1]
        self.wq_a_shape = [4096, 1024]
        self.wq_b_shape = [1024, 64 * 512]
        self.wkv_shape = [4096, 512]
        self.rmsnorm_gamma_cq_shape = [1024]
        self.rmsnorm_gamma_ckv_shape = [512]
        for t in self.t_vec:
            # cos and sin share one shape (and one list object, as before).
            cos_sin_shape = [t, 64]
            self.vec_all_shape[t] = [
                [t, 4096],
                self.wq_a_shape,
                self.wq_b_shape,
                self.wkv_shape,
                self.rmsnorm_gamma_cq_shape,
                self.rmsnorm_gamma_ckv_shape,
                cos_sin_shape,
                cos_sin_shape,
                [t, 64, 512],
                [t, 512],
                [t, 1024],
            ]

    def infer_controlflow_shape(self, *args):
        """Return the shape set(s) for a given input shape.

        Args:
            *args: Either nothing, or the shape of ``x`` as first argument
                (only its leading extent is inspected).

        Returns:
            With no arguments, a list holding the shape set of every bucket.
            Otherwise the shape set of the largest bucket that does not
            exceed ``args[0][0]``, or ``None`` when that extent is smaller
            than every bucket (i.e. below 1).
        """
        if not args:
            return list(self.vec_all_shape.values())
        token_count = args[0][0]
        for bucket in self.t_vec:
            if token_count >= bucket:
                return self.vec_all_shape[bucket]
        # No bucket fits; the implicit None of the original is kept explicit.
        return None


manager = MLAKernelMAnager()
# JIT-compiled kernel entry point. The token axis of x, cos, sin and of the
# three outputs is declared DYNAMIC; every other axis is STATIC. All tensors
# are BF16. Shape candidates come from manager.infer_controlflow_shape.
@pypto.frontend.jit(
    runtime_options={
        "stitch_function_max_num": 128,
    },
    infer_controlflow_shape=manager.infer_controlflow_shape,
)
def mla_prolog_v4(
        x: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        wq_a: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
        wq_b: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
        wkv: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
        rmsnorm_gamma_cq: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
        rmsnorm_gamma_ckv: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
        cos: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        sin: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        q_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
        kv_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        qr_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        attrs, configs):
    # Operation options are set before the graph is built by the compute body.
    pypto.experimental.set_operation_options(combine_axis=True)
    # Results are written into q_out / kv_out / qr_out; nothing is returned.
    mla_prolog_v4_compute(
        x, wq_a, wq_b, wkv,
        rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
        cos, sin,
        q_out, kv_out, qr_out,
        attrs, configs,
    )
@allow_in_graph
def mla_prolog_v4_in(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv):
    """Allocate the outputs, validate all tensors and launch the kernel.

    Args:
        token_x: Hidden states; its dtype and device are used for the outputs.
        wq_a: Weight of the first query projection.
        wq_b: Weight of the second query projection.
        wkv: Weight of the key/value projection.
        rope_cos: RoPE cosine table.
        rope_sin: RoPE sine table.
        gamma_cq: RMS-norm scale of the low-rank query.
        gamma_ckv: RMS-norm scale of the key/value; its length is the head dim.

    Returns:
        Tuple ``(q, kv, qr)`` with shapes ``[T, wq_b.size(1) // D, D]``,
        ``[T, D]`` and ``[T, gamma_cq.size(0)]``, where ``T = token_x.size(0)``
        and ``D = gamma_ckv.size(0)``.

    Raises:
        AssertionError: Propagated from ``check_input_output_shape_dtype``.
    """
    token_num = token_x.size(0)
    head_dim = gamma_ckv.size(0)
    head_num = wq_b.size(1) // head_dim
    q_lora_rank = gamma_cq.size(0)

    # Pass the torch.device itself, as the Meta kernel does, instead of
    # formatting it into a string first.
    out_options = {"dtype": token_x.dtype, "device": token_x.device}
    output_q_data = torch.zeros([token_num, head_num, head_dim], **out_options)
    output_kv_data = torch.zeros([token_num, head_dim], **out_options)
    output_qr_data = torch.zeros([token_num, q_lora_rank], **out_options)

    check_input_output_shape_dtype(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv,
                                   output_q_data, output_kv_data, output_qr_data)

    attrs = MlaPrologV4Attrs(eps=1e-6, layout_query="TND", layout_key="PA_BSND")
    configs = MlaPrologV4Configs(unroll_list=[128, 64, 32, 16, 1],
                                 cube_l1_reuse_setting={2: 4},
                                 mg_copyin_upper_bound=2 * 1024 * 1024,
                                 pg_upper_bound=8192,
                                 block_size=128,
                                 t_sub_tile=1,
                                 chunk_size=2)

    # The kernel takes the gammas BEFORE cos/sin, unlike this function.
    params_info = [
        token_x, wq_a, wq_b, wkv, gamma_cq, gamma_ckv,
        rope_cos, rope_sin, output_q_data, output_kv_data, output_qr_data,
    ]
    mla_prolog_v4(*params_info, attrs, configs)
    return output_q_data, output_kv_data, output_qr_data
pyptolib = torch.library.Library("pypto", "FRAGMENT")
pyptolib.define("mla_prolog(Tensor token_x, Tensor wq_a, Tensor wq_b, Tensor wkv, Tensor rope_cos, Tensor rope_sin, \
Tensor gamma_cq, Tensor gamma_ckv) -> (Tensor, Tensor, Tensor)")
@torch.library.impl(pyptolib, "mla_prolog", "Meta")
def mla_prolog(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv):
    """Meta kernel: produce uninitialized outputs of the right shape only.

    Returns:
        Tuple ``(q_out, kv_out, qr_out)`` shaped ``[T, wq_b.size(1) // D, D]``,
        ``[T, D]`` and ``[T, gamma_cq.size(0)]`` with the dtype and device of
        ``token_x``, where ``T = token_x.size(0)`` and ``D = gamma_ckv.size(0)``.
    """
    token_num = token_x.size(0)
    head_dim = gamma_ckv.size(0)
    like_input = {"dtype": token_x.dtype, "device": token_x.device}

    q_out = torch.empty([token_num, wq_b.size(1) // head_dim, head_dim], **like_input)
    kv_out = torch.empty([token_num, head_dim], **like_input)
    qr_out = torch.empty([token_num, gamma_cq.size(0)], **like_input)
    return q_out, kv_out, qr_out
@torch.library.impl(pyptolib, "mla_prolog", "NPU")
def mla_prolog(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv):
    """NPU kernel: forward all operands unchanged to ``mla_prolog_v4_in``."""
    operands = (token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv)
    return mla_prolog_v4_in(*operands)
def mla_prolog_pypto(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv):
    """Call the registered ``pypto::mla_prolog`` op and return its result.

    Returns:
        Whatever the dispatched kernel returns: a ``(q, kv, qr)`` tensor tuple.
    """
    registered_op = torch.ops.pypto.mla_prolog
    return registered_op(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv)