"""
"""
from dataclasses import dataclass
import math
from typing import List, Tuple
import torch
from torch._dynamo import allow_in_graph
from torch._subclasses.fake_tensor import FakeTensor
import torch_npu
import pypto
from pypto import pypto_impl
from pypto.operation import op_wrapper
"""
MLA Prolog Quantization Module
This module implements the MLA (Multi-head Latent Attention) prolog with
quantization for the DeepSeek V4 model. It projects hidden states into the
query, key-value and quantized query-latent (qr) outputs, with INT8
quantization and RoPE (Rotary Position Embedding) support.
Main Functions:
- mla_prolog_v4_compute: Core MLA prolog computation with quantization
- mla_prolog_quant_pypto: Entry point that calls the registered torch operator
- rms_norm: RMS normalization implementation
- quant: Quantization function with symmetry and smooth factor support
- dequant: Dequantization function
- rope_2d: 2D RoPE implementation
- rope_3d: 3D RoPE implementation
Example:
See test_mla_prolog_quant_v4.py for usage examples.
"""
# Tensor ranks referenced by shape assertions.
SHAPE_DIM_2, SHAPE_DIM_3 = 2, 3

# Named small integers and model sizes.
NUM_0, NUM_1, NUM_2, NUM_3 = range(4)
NUM_4096 = 4096
NUM_512 = 512

# Element counts of various parameter/shape lists.
TILE_CUBE_DIM = 6
Q_PARAM_DIM = 2
NZ_DIM = 4
COS_SIN_DIM = 2

# Positions inside a cube tile list of TILE_CUBE_DIM entries
# (names suggest L0/L1 tiles for the M, K and N axes -- confirm against pypto).
L0M_INDEX, L1M_INDEX, L0K_INDEX, L1K_INDEX, L0N_INDEX, L1N_INDEX = range(TILE_CUBE_DIM)

# Axis used for scatter operations (second-to-last).
SCATTER_DIM = -2

# NZ-format block sizes; NOTE(review): B8/B16 presumably mean 8-bit / 16-bit
# element widths -- confirm.
NZ_FIRST_DIM = 16
NZ_B8_C0 = 32
NZ_B16_C0 = 16

# Candidate vector tile sizes.
VEC_TILE_256, VEC_TILE_128, VEC_TILE_64, VEC_TILE_32 = 256, 128, 64, 32
VEC_TILE_8, VEC_TILE_4 = 8, 4
@dataclass
class MlaTileConfigs:
    """Vector tile shapes grouped by tensor rank.

    Each field is a list that the kernel unpacks into
    ``pypto.set_vec_tile_shapes(*field)``.
    """

    # Tile for 2-D tensors.
    two_dim_tile: list
    # Tile for 3-D tensors (rope_3d sets this around its 3-D steps).
    three_dim_tile: list
    # Tile for 4-D tensors (rope_3d sets this around its 4-D transposes).
    four_dim_tile: list
    # Tile used when casting the cos/sin slices in mla_prolog_v4_compute.
    vec_tile: list
@dataclass
class MlaPrologV4Output:
    """Container for the tensors produced by the MLA prolog.

    NOTE(review): this class is not constructed by the code shown here; the
    field meanings below are inferred from the output names of
    ``mla_prolog_v4_in`` -- confirm with callers.
    """

    # Query output (presumably ``q_out``).
    q: torch.Tensor
    # Key/value latent output (presumably ``kv_out``).
    kv: torch.Tensor
    # Quantized query latent (presumably ``qr_out``).
    qr: torch.Tensor
@dataclass
class MlaPrologV4Attrs:
    """Scalar attributes forwarded to the MLA prolog kernel."""

    # Epsilon added to the mean of squares inside every RMS norm
    # (mla_prolog_v4_in passes 1e-6).
    eps: float
@dataclass
class MlaPrologV4Configs:
    """Scheduling/tiling knobs for the MLA prolog kernel.

    Only ``unroll_list`` is read by ``mla_prolog_v4_compute``; the remaining
    fields are carried along with it.
    """

    # Chunk sizes handed to pypto.loop_unroll over the token axis.
    unroll_list: list[int]
    # NOTE(review): meaning of the key/value pairs is not visible here -- confirm.
    cube_l1_reuse_setting: dict[int, int]
    mg_copyin_upper_bound: int
    pg_upper_bound: int
    block_size: int
    t_sub_tile: int
    chunk_size: int
def _expect(condition, message):
    """Raise ``AssertionError(message)`` when ``condition`` is falsy.

    Unlike a bare ``assert`` statement, this check is not removed when Python
    runs with ``-O``, so input validation always happens.
    """
    if not condition:
        raise AssertionError(message)


def check_input_output_shape_dtype(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv,
                                   wq_b_scale, output_q_data, output_kv_data, output_qr_data,
                                   output_qr_scale_data):
    """Validate shapes and dtypes of the MLA prolog inputs and preallocated outputs.

    The expected sizes are fixed: hidden size 4096, q LoRA rank 1024, KV rank 512,
    64 query heads of width 512 (wq_b axis1 = 64 * 512 = 32768) and RoPE width 64.
    The token axis (axis0 of token_x, rope_cos, rope_sin and the outputs) is not
    constrained.

    Raises:
        AssertionError: on the first tensor whose rank, size or dtype does not
            match. Ranks are checked before any ``size(i)`` call so a tensor of
            the wrong rank produces this error instead of an IndexError.
    """
    _expect(token_x.dim() == 2 and token_x.size(1) == 4096,
            f"expected token_x dim num 2, token_x axis1 4096, but got {token_x.shape}")
    _expect(wq_a.dim() == 2 and wq_a.size(0) == 4096 and wq_a.size(1) == 1024,
            f"expected wq_a dim num 2, wq_a axis0 4096, wq_a axis1 1024, but got {wq_a.shape}")
    _expect(wq_b.dim() == 2 and wq_b.size(0) == 1024 and wq_b.size(1) == 32768,
            f"expected wq_b dim num 2, wq_b axis0 1024, wq_b axis1 32768, but got {wq_b.shape}")
    _expect(wkv.dim() == 2 and wkv.size(0) == 4096 and wkv.size(1) == 512,
            f"expected wkv dim num 2, wkv axis0 4096, wkv axis1 512, but got {wkv.shape}")
    _expect(rope_cos.dim() == 2 and rope_cos.size(1) == 64,
            f"expected rope_cos dim num 2, rope_cos axis1 64, but got {rope_cos.shape}")
    _expect(rope_sin.dim() == 2 and rope_sin.size(1) == 64,
            f"expected rope_sin dim num 2, rope_sin axis1 64, but got {rope_sin.shape}")
    _expect(gamma_cq.dim() == 1 and gamma_cq.size(0) == 1024,
            f"expected gamma_cq dim num 1, gamma_cq axis0 1024, but got {gamma_cq.shape}")
    _expect(gamma_ckv.dim() == 1 and gamma_ckv.size(0) == 512,
            f"expected gamma_ckv dim num 1, gamma_ckv axis0 512, but got {gamma_ckv.shape}")
    _expect(wq_b_scale.dim() == 2 and wq_b_scale.size(0) == 32768 and wq_b_scale.size(1) == 1,
            f"expected wq_b_scale dim num 2, wq_b_scale axis0 32768, wq_b_scale axis1 1, "
            f"but got {wq_b_scale.shape}")
    _expect(output_q_data.dim() == 3 and output_q_data.size(1) == 64 and output_q_data.size(2) == 512,
            f"expected output_q_data dim num 3, output_q_data axis1 64, output_q_data axis2 512, "
            f"but got {output_q_data.shape}")
    _expect(output_kv_data.dim() == 2 and output_kv_data.size(1) == 512,
            f"expected output_kv_data dim num 2, output_kv_data axis1 512, but got {output_kv_data.shape}")
    _expect(output_qr_data.dim() == 2 and output_qr_data.size(1) == 1024,
            f"expected output_qr_data dim num 2, output_qr_data axis1 1024, but got {output_qr_data.shape}")
    _expect(output_qr_scale_data.dim() == 2 and output_qr_scale_data.size(1) == 1,
            f"expected output_qr_scale_data dim num 2, output_qr_scale_data axis1 1, "
            f"but got {output_qr_scale_data.shape}")

    # (name, tensor, required dtype) in the same order the shapes were checked.
    expected_dtypes = (
        ("token_x", token_x, torch.bfloat16),
        ("wq_a", wq_a, torch.bfloat16),
        ("wq_b", wq_b, torch.int8),
        ("wkv", wkv, torch.bfloat16),
        ("rope_cos", rope_cos, torch.bfloat16),
        ("rope_sin", rope_sin, torch.bfloat16),
        ("gamma_cq", gamma_cq, torch.bfloat16),
        ("gamma_ckv", gamma_ckv, torch.bfloat16),
        ("wq_b_scale", wq_b_scale, torch.float32),
        ("output_q_data", output_q_data, torch.bfloat16),
        ("output_kv_data", output_kv_data, torch.bfloat16),
        ("output_qr_data", output_qr_data, torch.int8),
        ("output_qr_scale_data", output_qr_scale_data, torch.float32),
    )
    for name, tensor, expected in expected_dtypes:
        _expect(tensor.dtype == expected, f"{name}.dtype is {tensor.dtype}, expected {expected}")
def quant(
        input_tensor: pypto.Tensor,
        is_symmetry: bool = True,
        has_smooth_factor: bool = False,
        smooth_factor: pypto.Tensor = None) -> Tuple[pypto.Tensor, pypto.Tensor]:
    """Quantize ``input_tensor`` to INT8 row-wise along the last axis.

    The input is cast to FP32 when it is not already FP32, and is optionally
    multiplied by ``smooth_factor`` first.

    Symmetric mode (default):
        ``q = x * (127 / max|x|)`` per row, rounded and saturated to INT8.
        The returned dequant scale is ``1 / (127 / max|x|)`` (= ``max|x| / 127``).

    Asymmetric mode:
        ``scale = max((max - min) / 255, 1e-12)`` and
        ``offset = 127 - max / scale`` per row, then ``q = x / scale + offset``.
        This maps the row range ``[min, max]`` onto ``[-128, 127]``.

    Args:
        input_tensor: Tensor to quantize; the reduction is over the last axis.
        is_symmetry: Select symmetric (True) or asymmetric (False) quantization.
        has_smooth_factor: Multiply by ``smooth_factor`` before quantizing.
        smooth_factor: Multiplier applied when ``has_smooth_factor`` is True.

    Returns:
        ``(quantized, dequant_scale)``: the INT8 tensor and an FP32 per-row
        scale with a kept last axis of size 1.

    Note:
        In asymmetric mode the per-row offset is NOT returned, so the result
        cannot be dequantized from the return value alone. NOTE(review): only
        the symmetric path is used in this module; extend the return value if
        asymmetric quantization is needed.
        In symmetric mode an all-zero row makes ``max|x|`` zero, so the quant
        scale divides by zero -- TODO confirm such rows cannot occur.
    """
    if input_tensor.dtype != pypto.DT_FP32:
        input_tensor_fp32 = pypto.cast(input_tensor, pypto.DT_FP32)
    else:
        input_tensor_fp32 = input_tensor
    if has_smooth_factor:
        input_tensor_fp32 = pypto.mul(input_tensor_fp32, smooth_factor)
    if is_symmetry:
        abs_res = pypto.abs(input_tensor_fp32)
        max_value = pypto.amax(abs_res, -1, keepdim=True)
        scale_quant = pypto.div(pypto.full(max_value.shape, 127.0, pypto.DT_FP32), max_value)
        out_fp32 = pypto.mul(input_tensor_fp32, scale_quant)
        # Round to nearest integer, then go through FP16 to a saturating INT8 cast.
        out_int32 = pypto.cast(out_fp32, pypto.DT_INT32, pypto.CastMode.CAST_RINT)
        out_half = pypto.cast(out_int32, pypto.DT_FP16, pypto.CastMode.CAST_ROUND)
        out_int8 = pypto.cast(out_half, pypto.DT_INT8, pypto.CastMode.CAST_TRUNC, satmode=pypto.SaturationMode.ON)
        scale_de_quant = pypto.div(pypto.full(scale_quant.shape, 1.0, pypto.DT_FP32), scale_quant)
        return out_int8, scale_de_quant
    else:
        max_value = pypto.amax(input_tensor_fp32, -1, keepdim=True)
        min_value = pypto.amin(input_tensor_fp32, -1, keepdim=True)
        # The 1e-12 floor keeps constant rows (max == min) from dividing by zero.
        scale_de_quant = pypto.max(pypto.div(pypto.sub(max_value, min_value), 255.0), 1e-12)
        # Shift so that the row maximum lands on 127 (and the minimum on -128).
        offset = pypto.sub(127.0, pypto.div(max_value, scale_de_quant))
        # The quant scale must be the reciprocal of the dequant scale, and the
        # offset must be applied; scaling by 1 / max would squash every value
        # into [-1, 1] before rounding.
        scale_quant = pypto.div(pypto.full(scale_de_quant.shape, 1.0, pypto.DT_FP32), scale_de_quant)
        out_fp32 = pypto.add(pypto.mul(input_tensor_fp32, scale_quant), offset)
        out_int32 = pypto.cast(out_fp32, pypto.DT_INT32, pypto.CastMode.CAST_RINT)
        out_half = pypto.cast(out_int32, pypto.DT_FP16, pypto.CastMode.CAST_ROUND)
        out_int8 = pypto.cast(out_half, pypto.DT_INT8, pypto.CastMode.CAST_TRUNC, satmode=pypto.SaturationMode.ON)
        return out_int8, scale_de_quant
def dequant(
        input_tensor: pypto.Tensor, scale: pypto.Tensor, w_scale: pypto.Tensor
) -> pypto.Tensor:
    """Rescale an integer tensor back to FP32 using two multiplicative scales.

    Computes ``(cast_fp32(input_tensor) * scale) * w_scale``. The result stays
    in FP32; no cast to any other dtype is performed.

    Args:
        input_tensor: Integer tensor to rescale. In this module it is the INT32
            result of the INT8 ``qr_quant @ wq_b`` matmul.
        scale: First multiplier; in this module the per-token ``qr_scale``
            returned by ``quant``.
        w_scale: Second multiplier; in this module ``wq_b_scale`` reshaped
            to a single row.

    Returns:
        FP32 tensor with the broadcast shape of the three operands.
    """
    as_fp32 = pypto.cast(input_tensor, pypto.DT_FP32)
    token_scaled = as_fp32 * scale
    return token_scaled * w_scale
def rms_norm(input_tensor: pypto.Tensor, epsilon: float) -> pypto.Tensor:
    """Divide ``input_tensor`` by the root-mean-square of its last axis.

    Computes ``input / sqrt(mean(input ** 2, axis=-1) + epsilon)`` with the
    reduced axis kept, so the divisor broadcasts back over the last axis.
    No gamma weight is applied here; callers multiply by gamma afterwards when
    needed. No dtype cast is done on the input: arithmetic runs in the input's
    dtype (the callers in this module pass FP32 tensors). The ``1.0`` numerator
    tensor used for the reciprocal is created as FP32.

    Args:
        input_tensor: Tensor to normalize along its last axis.
        epsilon: Added to the mean of squares before the square root.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    rank = len(input_tensor.shape)
    last_axis_len = input_tensor.shape[rank - 1]
    squared = pypto.mul(input_tensor, input_tensor)
    mean_sq = pypto.mul(pypto.sum(squared, -1, keepdim=True), 1.0 / last_axis_len)
    rms = pypto.sqrt(pypto.add(mean_sq, epsilon))
    inv_rms = pypto.div(pypto.full(rms.shape, 1.0, pypto.DT_FP32), rms)
    return pypto.mul(input_tensor, inv_rms)
def rotate_half(input_tensor: pypto.Tensor) -> pypto.Tensor:
    """Return ``concat([-x2, x1], axis=-1)`` for the two halves of the last axis.

    ``x1`` is the first half and ``x2`` the second half of the last axis. This
    is the half-rotation used by RoPE.

    Args:
        input_tensor: Tensor whose last axis is assumed to have even length.
            No validation is performed.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    full_shape = input_tensor.shape
    rank = len(full_shape)
    half_shape = list(full_shape)
    half_shape[rank - 1] //= 2

    second_offset = [0] * rank
    second_offset[rank - 1] = half_shape[rank - 1]

    first_half = pypto.view(input_tensor, half_shape, [0] * rank)
    second_half = pypto.view(input_tensor, half_shape, second_offset)
    # `+ 0.0` presumably materializes the view as a fresh tensor before
    # concatenation -- confirm with pypto semantics.
    return pypto.concat([second_half * (-1.0), first_half + 0.0], -1)
def rope_2d(
        x: pypto.Tensor, cos: pypto.Tensor, sin: pypto.Tensor) -> pypto.Tensor:
    """Apply interleaved rotary position embedding (RoPE) to a 2-D tensor.

    The last axis of ``x`` is treated as interleaved pairs ``(x[2i], x[2i+1])``.
    The rotated tensor is built in three steps:

    1. De-interleave: reshape to ``[rows, d/2, 2]``, transpose, and reshape back.
       This yields the even elements followed by the odd elements.
    2. Apply ``rotate_half``.
    3. Re-interleave with the inverse reshape/transpose.

    The result is ``rot[2i] = -x[2i+1]``, ``rot[2i+1] = x[2i]``. The output is
    ``x * cos + rot * sin``, computed in FP32 and cast back to ``x.dtype``.

    Args:
        x: 2-D input ``[rows, d]``; ``d`` is assumed even.
        cos: 2-D cosine table combined element-wise with ``x`` (callers in this
            module pass an FP32 ``[rows, d]`` slice).
        sin: 2-D sine table, same layout as ``cos``.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: if ``x``, ``cos`` or ``sin`` is not 2-D.
    """
    assert len(x.shape) == 2 and len(cos.shape) == 2 and len(sin.shape) == 2, \
        "rope_2d expects 2-D x, cos and sin"
    input_dtype = x.dtype
    pypto.set_vec_tile_shapes(4, 64, 64)
    x_cast = pypto.cast(x, pypto.DT_FP32)
    # De-interleave pairs: [rows, d] -> [rows, d/2, 2] -> [rows, 2, d/2] -> [rows, d].
    x_view = pypto.reshape(x, [x.shape[0], x.shape[1]//2, 2])
    x_trans = pypto.transpose(x_view, 1, 2)
    x_re_second = pypto.reshape(x_trans, x.shape)
    x_t = rotate_half(x_re_second)
    # Re-interleave: [rows, d] -> [rows, 2, d/2] -> [rows, d/2, 2] -> [rows, d].
    x_new = pypto.reshape(x_t, [x.shape[0], 2, x.shape[1]//2])
    x_new_trans = pypto.transpose(x_new, 1, 2)
    x_new_r = pypto.reshape(x_new_trans, x.shape)
    x_new_cast = pypto.cast(x_new_r, pypto.DT_FP32)
    pypto.set_vec_tile_shapes(4, 64, 64)
    x_embed = x_cast * cos + x_new_cast * sin
    return pypto.cast(x_embed, input_dtype)
def rope_3d(x: pypto.Tensor, cos: pypto.Tensor, sin: pypto.Tensor, tile_configs: MlaTileConfigs) -> pypto.Tensor:
    """Apply interleaved RoPE to every head of a 3-D tensor.

    ``x`` is ``[t, h, d]``. Per head, the same interleaved rotation as ``rope_2d``
    is applied: ``rot[2i] = -x[2i+1]``, ``rot[2i+1] = x[2i]``, and the output is
    ``x * cos + rot * sin``. ``cos``/``sin`` are reshaped to ``[t, 1, d]`` so they
    broadcast over the head axis. Arithmetic is FP32 and the result is cast back
    to ``x.dtype``.

    NOTE(review): the previous summary called this an "inverse" embedding; the
    code applies the forward ``x * cos + rot * sin`` form -- confirm intended naming.

    Args:
        x: 3-D input ``[t, h, d]``; ``d`` is assumed even.
        cos: 2-D cosine table ``[t, d]`` (callers in this module pass FP32).
        sin: 2-D sine table ``[t, d]``.
        tile_configs: Supplies the 3-D and 4-D vector tile shapes set between steps.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    assert (len(x.shape) == SHAPE_DIM_3 and len(cos.shape) == SHAPE_DIM_2 and len(sin.shape) == SHAPE_DIM_2)
    pypto.set_vec_tile_shapes(*tile_configs.three_dim_tile)
    cast_x = pypto.cast(x, pypto.DataType.DT_FP32)
    cast_cos = pypto.reshape(cos, [x.shape[0], 1, x.shape[2]])
    cast_sin = pypto.reshape(sin, [x.shape[0], 1, x.shape[2]])
    # De-interleave pairs along the last axis: [t, h, d/2, 2] -> [t, h, 2, d/2] -> [t, h, d].
    x_view = pypto.reshape(cast_x, [x.shape[0], x.shape[1], x.shape[2] // 2, 2])
    pypto.set_vec_tile_shapes(*tile_configs.four_dim_tile)
    x_trans = pypto.transpose(x_view, 2, 3)
    x_re_second = pypto.reshape(x_trans, x.shape)
    pypto.set_vec_tile_shapes(*tile_configs.three_dim_tile)
    x_rotate = rotate_half(x_re_second)
    # Re-interleave via a [t, d, h] layout: [t, d, h] -> [t, 2, d/2, h] -> [t, d/2, 2, h].
    x_rotate_trs_1 = pypto.transpose(x_rotate, 1, 2)
    x_rotate_reshape_1 = pypto.reshape(x_rotate_trs_1, [
        x_rotate_trs_1.shape[0], 2, x_rotate_trs_1.shape[1] // 2, x_rotate_trs_1.shape[2]])
    pypto.set_vec_tile_shapes(*tile_configs.four_dim_tile)
    x_rotate_trs_2 = pypto.transpose(x_rotate_reshape_1, 1, 2)
    # Presumably forces the transposed data to materialize -- confirm.
    x_rotate_trs_2 = x_rotate_trs_2 + 0.0
    # The data is laid out as [t, d, h], so reshape to that shape (not
    # x_rotate.shape == [t, h, d], which only coincides when h == d) before
    # transposing back to [t, h, d].
    x_rotate_reshape_2 = pypto.reshape(x_rotate_trs_2, x_rotate_trs_1.shape)
    pypto.set_vec_tile_shapes(*tile_configs.three_dim_tile)
    x_rotate_res = pypto.transpose(x_rotate_reshape_2, 1, 2)
    pypto.set_vec_tile_shapes(*tile_configs.three_dim_tile)
    x_embed = cast_x * cast_cos + x_rotate_res * cast_sin
    x_embed_cast = pypto.cast(x_embed, x.dtype)
    return x_embed_cast
def mla_prolog_v4_compute(x, wq_a, wq_b, wkv, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, cos, sin, \
        wq_b_scale, q_out, kv_out, qr_out, qr_scale_out, attrs, configs, tile_configs):
    """Build the PyPTO graph for the MLA prolog, writing results into the *_out tensors.

    For each chunk of tokens produced by ``pypto.loop_unroll`` the function
    computes three paths.

    * q path: ``qr = rms_norm(x @ wq_a) * gamma_cq``, quantized to INT8. The
      quantized values and their scale are written to ``qr_out`` and
      ``qr_scale_out``. Then ``qr_quant @ wq_b`` (INT32) is dequantized with the
      token scale and ``wq_b_scale``, and reshaped to ``[tokens, head_num, head_dim]``.
      It is RMS-normalized per head (no gamma) and cast to BF16. The first
      ``head_dim - rope_dim`` columns go to ``q_out`` unchanged; the last
      ``rope_dim`` columns go through ``rope_3d``.
    * kv path: ``rms_norm(x @ wkv) * gamma_ckv`` is cast to BF16. The non-RoPE
      columns go to ``kv_out`` unchanged; the last ``rope_dim`` columns go
      through ``rope_2d``.

    Sizes are derived from the inputs:
    ``q_lora_rank = len(gamma_cq)``, ``head_dim = len(gamma_ckv)``,
    ``head_num = wq_b.shape[1] // head_dim`` and ``rope_dim = cos.shape[1]``.
    Only ``configs.unroll_list`` is read from ``configs``.

    NOTE: the many ``pypto.set_vec_tile_shapes`` / ``set_cube_tile_shapes``
    calls set state consumed by the following ops, so statement order matters.
    """
    t = x.shape[0]
    h = x.shape[1]
    q_lora_rank = rmsnorm_gamma_cq.shape[0]
    # The q head width is taken from the KV-norm gamma, so q heads and the kv latent share one width.
    head_dim = rmsnorm_gamma_ckv.shape[0]
    head_num = wq_b.shape[1] // head_dim
    rope_dim = cos.shape[1]
    # K-split size for the x @ wq_a matmul below.
    k_tile = 2048
    # Lift 1-D parameters to single-row 2-D tensors so they broadcast over tokens.
    gamma_cq_2d = pypto.reshape(rmsnorm_gamma_cq, [1, rmsnorm_gamma_cq.shape[0]], inplace=True)
    gamma_ckv_2d = pypto.reshape(rmsnorm_gamma_ckv, [1, rmsnorm_gamma_ckv.shape[0]], inplace=True)
    wq_b_scale = pypto.reshape(wq_b_scale, [1, wq_b_scale.shape[0]], inplace=True)
    pypto.set_vec_tile_shapes(4, q_lora_rank)
    gamma_cq_2d_fp32 = pypto.cast(gamma_cq_2d, pypto.DataType.DT_FP32)
    gamma_ckv_2d_fp32 = pypto.cast(gamma_ckv_2d, pypto.DataType.DT_FP32)
    unroll_list = configs.unroll_list
    # tIdx is the token offset of the chunk and unrollLength its size (presumably
    # chosen from unroll_list by pypto.loop_unroll -- confirm).
    for tIdx, unrollLength in pypto.loop_unroll(0, t, 1, name="MLA_BS_LOOP", idx_name="bs_offset",
                                                unroll_list=unroll_list, ):
        t_tile = unrollLength
        tile_bs = min(t_tile, 128)
        pypto.set_vec_tile_shapes(4, 4096)
        x_tile = pypto.view(x, [t_tile, h], [tIdx, 0], valid_shape=[t_tile, h])

        # ---- q path: x @ wq_a, accumulated over two K chunks ----
        pypto.set_semantic_label("wqa-linear")
        pypto.set_cube_tile_shapes([tile_bs, tile_bs], [256, 256], [128, 128])
        # NOTE(review): two chunks of k_tile cover h only when h == 2 * k_tile
        # (4096, as check_input_output_shape_dtype enforces).
        for i in range(2):
            x_tile1 = pypto.view(x_tile, [t_tile, k_tile], [0, i*k_tile])
            wq_a_tile1 = pypto.view(wq_a, [k_tile, q_lora_rank], [i*k_tile, 0])
            if i==0:
                q = pypto.matmul(x_tile1, wq_a_tile1, pypto.DT_FP32)
            else:
                q = q + pypto.matmul(x_tile1, wq_a_tile1, pypto.DT_FP32)
        pypto.set_semantic_label("q-rmsnorm with weight")
        qr = rms_norm(q, attrs.eps)
        qr = pypto.mul(qr, gamma_cq_2d_fp32)
        # Symmetric per-token INT8 quantization; both results are also module outputs.
        qr_quant, qr_scale = quant(qr)
        pypto.assemble(qr_quant, [tIdx, 0], qr_out)
        pypto.assemble(qr_scale, [tIdx, 0], qr_scale_out)

        # ---- q path: INT8 qr @ wq_b, dequantized, per-head RMS norm ----
        pypto.set_semantic_label("wqb-linear")
        pypto.set_cube_tile_shapes([tile_bs, tile_bs], [128, 1024, 512], [256, 256])
        qb = pypto.matmul(qr_quant, wq_b, pypto.DataType.DT_INT32)
        pypto.set_vec_tile_shapes(4, 4096)
        qb_dequant = dequant(qb, qr_scale, wq_b_scale)
        q_3d = pypto.reshape(qb_dequant, [t_tile, head_num, head_dim])
        pypto.set_vec_tile_shapes(4, 8, 512)
        qr2_3d = rms_norm(q_3d, attrs.eps)
        qr2_3d_cast = pypto.cast(qr2_3d, pypto.DataType.DT_BF16)
        # Non-RoPE part of each head is written to q_out as-is
        # (clone presumably materializes the view before assembling -- confirm).
        qr2_3d_nope = pypto.view(qr2_3d_cast, [t_tile, head_num, head_dim-rope_dim], \
            [0, 0, 0], valid_shape=[t_tile, head_num, head_dim-rope_dim])
        pypto.assemble(pypto.clone(qr2_3d_nope), [tIdx, 0, 0], q_out)

        # ---- RoPE tables for this token chunk (shared by q and kv) ----
        pypto.set_vec_tile_shapes(*tile_configs.vec_tile)
        cos_2d = pypto.view(cos, [t_tile, rope_dim], [tIdx, 0], valid_shape=[t_tile, rope_dim])
        sin_2d = pypto.view(sin, [t_tile, rope_dim], [tIdx, 0], valid_shape=[t_tile, rope_dim])
        cast_cos = pypto.cast(cos_2d, pypto.DT_FP32)
        cast_sin = pypto.cast(sin_2d, pypto.DT_FP32)
        pypto.set_vec_tile_shapes(4, 64, 64)
        # Last rope_dim columns of each head get RoPE and land at the same columns of q_out.
        qr2_3d_rope = pypto.view(qr2_3d_cast, [t_tile, head_num, rope_dim], [0, 0, head_dim-rope_dim], valid_shape=[t_tile, head_num, rope_dim])
        qr2_3d_rope = rope_3d(qr2_3d_rope, cast_cos, cast_sin, tile_configs)
        pypto.assemble(qr2_3d_rope, [tIdx, 0, head_dim-rope_dim], q_out)

        # ---- kv path: x @ wkv, RMS norm with gamma, RoPE on the tail ----
        pypto.set_semantic_label("wkv-linear")
        pypto.set_cube_tile_shapes([tile_bs, tile_bs], [256, 512], [128, 128], True)
        kv = pypto.matmul(x_tile, wkv, pypto.DataType.DT_FP32)
        pypto.set_vec_tile_shapes(8, 512)
        kv_norm = rms_norm(kv, attrs.eps)
        kv_norm = pypto.mul(kv_norm, gamma_ckv_2d_fp32)
        kv_norm_cast = pypto.cast(kv_norm, pypto.DataType.DT_BF16)
        kv_norm_nope = pypto.view(kv_norm_cast, [t_tile, head_dim-rope_dim], [0, 0], valid_shape=[t_tile, head_dim-rope_dim])
        pypto.assemble(pypto.clone(kv_norm_nope), [tIdx, 0], kv_out)
        kv_norm_rope = pypto.view(kv_norm_cast, [t_tile, rope_dim], [0, head_dim-rope_dim], valid_shape=[t_tile, rope_dim])
        kv_norm_rope = rope_2d(kv_norm_rope, cast_cos, cast_sin)
        pypto.assemble(kv_norm_rope, [tIdx, head_dim-rope_dim], kv_out)
class MLAKernelMAnager:
    """Provide static argument shapes for the ``mla_prolog_v4`` JIT.

    For each candidate token count in ``t_vec`` it precomputes the shapes of
    the 13 tensor arguments in the order ``mla_prolog_v4`` receives them:
    x, wq_a, wq_b, wkv, gamma_cq, gamma_ckv, cos, sin, wq_b_scale, q_out,
    kv_out, qr_out, qr_scale_out.
    """

    def __init__(self):
        self.vec_all_shape = {}
        # Candidate token counts, largest first; lookups pick the first one <= t.
        self.t_vec = [4096, 128, 64, 32, 16, 1]
        self.wq_a_shape = [4096, 1024]
        self.wq_b_shape = [1024, 64 * 512]
        self.wkv_shape = [4096, 512]
        self.rmsnorm_gamma_cq_shape = [1024]
        self.rmsnorm_gamma_ckv_shape = [512]
        self.wq_b_scale_shape = [64 * 512, 1]
        for t in self.t_vec:
            x_shape = [t, 4096]
            # cos and sin share one shape list object.
            rope_cos_shape = [t, 64]
            q_out_shape = [t, 64, 512]
            kv_out_shape = [t, 512]
            qr_out_shape = [t, 1024]
            qr_scale_out_shape = [t, 1]
            self.vec_all_shape[t] = [
                x_shape, self.wq_a_shape, self.wq_b_shape, self.wkv_shape,
                self.rmsnorm_gamma_cq_shape, self.rmsnorm_gamma_ckv_shape,
                rope_cos_shape, rope_cos_shape, self.wq_b_scale_shape, q_out_shape,
                kv_out_shape, qr_out_shape, qr_scale_out_shape,
            ]

    def infer_controlflow_shape(self, *args):
        """Return argument shapes for the control-flow shape inference hook.

        Args:
            *args: Either nothing, or the actual argument shapes with the
                ``x`` shape first.

        Returns:
            With no arguments, a list of the shape lists for every candidate
            in ``t_vec``. Otherwise, the shape list of the largest candidate
            token count not exceeding ``args[0][0]``, or None when no candidate
            qualifies (``args[0][0] < 1``).
        """
        if not args:
            return list(self.vec_all_shape.values())
        num_tokens = args[0][0]
        for t in self.t_vec:
            if num_tokens >= t:
                return self.vec_all_shape[t]
        return None
# Shape provider whose bound method is handed to the JIT below as its
# control-flow shape inference hook.
manager = MLAKernelMAnager()
# JIT-compiled PyPTO entry point for the MLA prolog. Tensor arguments are
# declared with DYNAMIC token axes and STATIC model axes; the weights use the
# NZ tile format. Results are written in place into q_out, kv_out, qr_out and
# qr_scale_out. attrs / configs / tile_configs are plain Python objects
# (MlaPrologV4Attrs, MlaPrologV4Configs, MlaTileConfigs in mla_prolog_v4_in).
# NOTE(review): the exact meaning of the runtime/pass option values is defined
# by pypto -- confirm there.
@pypto.frontend.jit(runtime_options={
    "stitch_function_max_num": 128,
    "device_sched_mode": 1
    },
    pass_options={
        "vec_nbuffer_setting": {-1: 2}
    },
    infer_controlflow_shape=manager.infer_controlflow_shape,
)
def mla_prolog_v4(
        x: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC] , pypto.DT_BF16),
        wq_a: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
        wq_b: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT8, format=pypto.TileOpFormat.TILEOP_NZ),
        wkv: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16, format=pypto.TileOpFormat.TILEOP_NZ),
        rmsnorm_gamma_cq: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
        rmsnorm_gamma_ckv: pypto.Tensor([pypto.STATIC], pypto.DT_BF16),
        cos: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        sin: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        wq_b_scale: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
        q_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC] , pypto.DT_BF16),
        kv_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC] , pypto.DT_BF16),
        qr_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC] , pypto.DT_INT8),
        qr_scale_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC] , pypto.DT_FP32),
        attrs, configs, tile_configs):
    # Operation option applied to the whole graph before it is built.
    pypto.experimental.set_operation_options(combine_axis=True)
    mla_prolog_v4_compute(x, wq_a, wq_b, wkv, rmsnorm_gamma_cq, rmsnorm_gamma_ckv, cos, sin, wq_b_scale, q_out, kv_out, qr_out, qr_scale_out, attrs, configs, tile_configs)
@allow_in_graph
def mla_prolog_v4_in(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale):
    """Allocate outputs, validate all tensors, and run the ``mla_prolog_v4`` kernel.

    Output shapes, with ``T = token_x.size(0)``:
        q:        ``[T, wq_b.size(1) // gamma_ckv.size(0), gamma_ckv.size(0)]``, token_x dtype
        kv:       ``[T, gamma_ckv.size(0)]``, token_x dtype
        qr:       ``[T, gamma_cq.size(0)]``, int8
        qr_scale: ``[T, 1]``, float32

    Returns:
        ``(q, kv, qr, qr_scale)``, filled by the kernel.

    Raises:
        AssertionError: from ``check_input_output_shape_dtype`` on a shape or
            dtype mismatch.
    """
    num_tokens = token_x.size(0)
    kv_rank = gamma_ckv.size(0)
    # Pass the device object straight through (as the Meta implementation
    # does) instead of round-tripping it through a string.
    device = token_x.device
    output_q_data = torch.empty([num_tokens, wq_b.size(1) // kv_rank, kv_rank],
                                dtype=token_x.dtype, device=device)
    output_kv_data = torch.empty([num_tokens, kv_rank], dtype=token_x.dtype, device=device)
    output_qr_data = torch.empty([num_tokens, gamma_cq.size(0)], dtype=torch.int8, device=device)
    output_qr_scale_data = torch.empty([num_tokens, 1], dtype=torch.float32, device=device)
    check_input_output_shape_dtype(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv,
                                   wq_b_scale, output_q_data, output_kv_data, output_qr_data,
                                   output_qr_scale_data)
    attrs = MlaPrologV4Attrs(eps=1e-6)
    tile_configs = MlaTileConfigs(
        two_dim_tile=[1, 64],
        three_dim_tile=[1, 64, 64],
        four_dim_tile=[1, 64, 64, 64],
        # Token-axis tile for the cos/sin casts grows with T (at least 1).
        vec_tile=[max(1, token_x.shape[0] // 16), 64]
    )
    configs = MlaPrologV4Configs(unroll_list=[128, 64, 32, 16, 1],
                                 cube_l1_reuse_setting={2: 4},
                                 mg_copyin_upper_bound=2 * 1024 * 1024,
                                 pg_upper_bound=8192,
                                 block_size=128,
                                 t_sub_tile=1,
                                 chunk_size=2)
    # Argument order must match the mla_prolog_v4 signature.
    params_info = [token_x, wq_a, wq_b, wkv, gamma_cq, gamma_ckv, rope_cos, rope_sin, wq_b_scale,
                   output_q_data, output_kv_data, output_qr_data, output_qr_scale_data]
    mla_prolog_v4(*params_info, attrs, configs, tile_configs)
    return output_q_data, output_kv_data, output_qr_data, output_qr_scale_data
# Library handle for the "pypto" operator namespace ("FRAGMENT" kind), used
# below to declare the schema of pypto::mla_prolog_quant: nine tensor inputs,
# four tensor outputs (q, kv, qr, qr_scale).
pyptolib = torch.library.Library("pypto", "FRAGMENT")
pyptolib.define("mla_prolog_quant(Tensor token_x, Tensor wq_a, Tensor wq_b, Tensor wkv, Tensor rope_cos, Tensor rope_sin, \
Tensor gamma_cq, Tensor gamma_ckv, Tensor wq_b_scale) -> (Tensor, Tensor, Tensor, Tensor)")
@torch.library.impl(pyptolib, "mla_prolog_quant", "Meta")
def mla_prolog_quant(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale):
    """Meta-key implementation of ``pypto::mla_prolog_quant``.

    Runs no computation; it returns uninitialized tensors on ``token_x.device``
    with the output shapes and dtypes (``T = token_x.size(0)``):
    q ``[T, wq_b.size(1) // gamma_ckv.size(0), gamma_ckv.size(0)]`` and
    kv ``[T, gamma_ckv.size(0)]`` in token_x's dtype,
    qr ``[T, gamma_cq.size(0)]`` int8, and qr_scale ``[T, 1]`` float32.
    """
    num_tokens = token_x.size(0)
    kv_rank = gamma_ckv.size(0)
    target = token_x.device
    q_shape = [num_tokens, wq_b.size(1) // kv_rank, kv_rank]
    kv_shape = [num_tokens, kv_rank]
    qr_shape = [num_tokens, gamma_cq.size(0)]
    qr_scale_shape = [num_tokens, 1]
    return (
        torch.empty(q_shape, dtype=token_x.dtype, device=target),
        torch.empty(kv_shape, dtype=token_x.dtype, device=target),
        torch.empty(qr_shape, dtype=torch.int8, device=target),
        torch.empty(qr_scale_shape, dtype=torch.float32, device=target),
    )
@torch.library.impl(pyptolib, "mla_prolog_quant", "NPU")
def mla_prolog_quant(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale):
    """NPU-key implementation of ``pypto::mla_prolog_quant``.

    Forwards every argument unchanged to ``mla_prolog_v4_in`` and returns its
    ``(q, kv, qr, qr_scale)`` tuple.
    """
    outputs = mla_prolog_v4_in(
        token_x, wq_a, wq_b, wkv,
        rope_cos, rope_sin,
        gamma_cq, gamma_ckv,
        wq_b_scale,
    )
    return outputs
def mla_prolog_quant_pypto(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale):
    """Invoke the registered ``torch.ops.pypto.mla_prolog_quant`` operator.

    Calling through the torch op lets PyTorch dispatch to whichever
    implementation is registered for the inputs' dispatch key (this module
    registers "Meta" and "NPU").

    Returns:
        ``(q, kv, qr, qr_scale)`` as produced by the operator.
    """
    op = torch.ops.pypto.mla_prolog_quant
    return op(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale)