""" """
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import pypto
from conftest import duration_estimate
SHAPE_DIM_0 = 0
SHAPE_DIM_1 = 1
SHAPE_DIM_2 = 2
SHAPE_DIM_3 = 3
NUM_NEG1 = -1
NUM_0 = 0
NUM_1 = 1
NUM_2 = 2
NUM_3 = 3
NUM_4 = 4
NUM_8 = 8
NUM_16 = 16
NUM_24 = 24
NUM_32 = 32
NUM_64 = 64
NUM_128 = 128
NUM_256 = 256
NUM_512 = 512
NUM_1024 = 1024
NUM_1127 = 1127
NUM_1536 = 1536
NUM_7168 = 7168
KEY_CUBE_NBUFFER_SETTING = "CUBE_NBUFFER_SETTING"
KEY_MG_COPYIN_UPPER_BOUND = "MG_COPYIN_UPPER_BOUND"
@dataclass
class MlaQuantInputs:
dequant_scale_x: Optional[pypto.Tensor] = None
dequant_scale_w_dq: Optional[pypto.Tensor] = None
dequant_scale_w_uq_qr: Optional[pypto.Tensor] = None
dequant_scale_w_dkv_kr: Optional[pypto.Tensor] = None
quant_scale_ckv: Optional[pypto.Tensor] = None
quant_scale_ckr: Optional[pypto.Tensor] = None
smooth_scales_cq: Optional[pypto.Tensor] = None
@dataclass
class RopeTileShapeConfig:
rope_2d: List[int]
rope_3d_vals: List[int]
rope_4d: List[int]
@dataclass
class MlaTileConfig:
tile_b: int = NUM_8
tile_s: int = NUM_1
rope: RopeTileShapeConfig = field(
default_factory=lambda: RopeTileShapeConfig(
rope_2d=[NUM_128, NUM_128],
rope_3d_vals=[NUM_32, NUM_128, NUM_128],
rope_4d=[NUM_16, NUM_128, NUM_128, NUM_128],
)
)
@dataclass
class MlaParams:
b: int
s1: int
n1: int
n2: int
h: int
q_lora_rank: int
kv_lora_rank: int
qk_rope_head_dim: int
qk_nope_head_dim: int
rope_dim: int
cache_mode: str
block_size: int
block_num: int
eps_cq: float
eps_ckv: float
tiles: MlaTileConfig
@staticmethod
def common() -> "MlaParams":
return MlaParams(
b=NUM_0,
s1=NUM_0,
n1=NUM_0,
n2=NUM_0,
h=NUM_7168,
q_lora_rank=NUM_1536,
kv_lora_rank=NUM_512,
qk_rope_head_dim=NUM_64,
qk_nope_head_dim=NUM_128,
rope_dim=NUM_64,
cache_mode="PA_BSND",
block_size=NUM_128,
block_num=NUM_0,
eps_cq=1e-5,
eps_ckv=1e-5,
tiles=MlaTileConfig(),
)
@dataclass
class MlaTensors:
x: pypto.Tensor
w_dq: pypto.Tensor
w_uq_qr: pypto.Tensor
w_uk: pypto.Tensor
w_dkv_kr: pypto.Tensor
gamma_cq: pypto.Tensor
gamma_ckv: pypto.Tensor
sin: pypto.Tensor
cos: pypto.Tensor
cache_index: pypto.Tensor
kv_cache: pypto.Tensor
kr_cache: pypto.Tensor
q_out: pypto.Tensor
q_rope_out: pypto.Tensor
kv_cache_out: pypto.Tensor
kr_cache_out: pypto.Tensor
rms_res: pypto.Tensor
@dataclass
class MlaArgs:
params: MlaParams
tensors: MlaTensors
quant: MlaQuantInputs
def quantize(
input_x: pypto.Tensor,
is_symmetry: bool = True,
has_smooth: bool = False,
smooth_factor: pypto.Tensor = None,
) -> Tuple[pypto.Tensor, pypto.Tensor]:
x = pypto.cast(input_x, pypto.DT_FP32)
if has_smooth:
x = x * smooth_factor
ones = pypto.tensor(1.0, pypto.DT_FP32)
z127 = pypto.tensor(127.0, pypto.DT_FP32)
z255 = pypto.tensor(255.0, pypto.DT_FP32)
eps = pypto.tensor(1e-12, pypto.DT_FP32)
if is_symmetry:
abs_x = pypto.abs(x)
max_val = pypto.amax(abs_x, -1)
scale_quant = z127 / max_val
out_fp32 = x / scale_quant
out_int_32 = pypto.cast(out_fp32, pypto.DT_INT32, pypto.CastMode.CAST_RINT)
out_half = pypto.cast(out_int_32, pypto.DT_FP16, pypto.CastMode.CAST_ROUND)
out_int_8 = pypto.cast(out_half, pypto.DT_INT8, pypto.CastMode.CAST_TRUNC, satmode=pypto.SaturationMode.ON)
scale_dequant = ones / scale_quant
return out_int_8, scale_dequant
else:
max_v = pypto.amax(x, -1)
min_v = pypto.amin(x, -1)
scale_dequant = pypto.maximum((max_v - min_v) / z255, eps)
scale_quant = ones / scale_dequant
out_fp32 = x * scale_quant
out_int_32 = pypto.cast(out_fp32, pypto.DT_INT32, pypto.CastMode.CAST_RINT)
out_half = pypto.cast(out_int_32, pypto.DT_FP16, pypto.CastMode.CAST_ROUND)
out_int_8 = pypto.cast(out_half, pypto.DT_INT8, pypto.CastMode.CAST_TRUNC, satmode=pypto.SaturationMode.ON)
return out_int_8, scale_dequant
def dequant(
dtype: pypto.DataType,
input_tensor: pypto.Tensor,
scale: pypto.Tensor,
w_scale: pypto.Tensor,
) -> pypto.Tensor:
out = pypto.cast(input_tensor, pypto.DT_FP32)
out = out * scale
out = out * w_scale
return pypto.cast(out, dtype)
def rotate_half(x: pypto.Tensor) -> pypto.Tensor:
shape = x.shape
nd = x.dim
assert nd >= 1
assert shape[nd - 1] % NUM_2 == 0
new_shape = list(shape)
new_shape[nd - 1] //= NUM_2
off1 = [0] * nd
off2 = [0] * nd
off2[nd - 1] = new_shape[nd - 1]
x1 = pypto.view(x, new_shape, off1)
x2 = pypto.view(x, new_shape, off2)
return pypto.concat([x2 * (-1.0), x1 + 0.0], -1)
def rope_2d(
x: pypto.Tensor, cos: pypto.Tensor, sin: pypto.Tensor, tiles: RopeTileShapeConfig
) -> pypto.Tensor:
assert x.dim == SHAPE_DIM_2 and cos.dim == SHAPE_DIM_2 and sin.dim == SHAPE_DIM_2
seq = x.shape[NUM_0]
d_r = x.shape[NUM_1]
x_dtype = x.dtype
pypto.set_vec_tile_shapes(tiles.rope_2d[0], tiles.rope_2d[1])
cast_x = pypto.cast(x, pypto.DT_FP32)
if x.dtype == pypto.DT_FP32:
cast_x = cast_x + 0.0
cast_cos = pypto.cast(cos, pypto.DT_FP32)
cast_sin = pypto.cast(sin, pypto.DT_FP32)
x_view = pypto.reshape(cast_x, [1, seq, d_r // NUM_2, NUM_2])
pypto.set_vec_tile_shapes(*tiles.rope_4d)
x_t = pypto.transpose(x_view, NUM_2, NUM_3)
x_flat = pypto.reshape(x_t, [seq, d_r])
pypto.set_vec_tile_shapes(tiles.rope_2d[0], tiles.rope_2d[1])
if not (x.shape[0] == cos.shape[0] and x.shape[1] == cos.shape[1]):
cast_cos[:] = pypto.expand(cast_cos, x.shape)
cast_sin[:] = pypto.expand(cast_sin, x.shape)
x_embed = x_flat * cast_cos + rotate_half(x_flat) * cast_sin
return pypto.cast(x_embed, x_dtype)
def rope_3d(
x: pypto.Tensor, cos: pypto.Tensor, sin: pypto.Tensor, tiles: RopeTileShapeConfig
) -> pypto.Tensor:
assert x.dim == SHAPE_DIM_3 and cos.dim == SHAPE_DIM_2 and sin.dim == SHAPE_DIM_2
pypto.set_vec_tile_shapes(NUM_1, NUM_32, NUM_128)
cast_x = pypto.cast(x, pypto.DT_FP32)
if x.dtype == pypto.DT_FP32:
cast_x = cast_x + 0.0
cast_cos = pypto.cast(cos, pypto.DT_FP32)
cast_sin = pypto.cast(sin, pypto.DT_FP32)
cast_cos = pypto.reshape(cast_cos, [x.shape[NUM_0], 1, x.shape[NUM_2]])
cast_sin = pypto.reshape(cast_sin, [x.shape[NUM_0], 1, x.shape[NUM_2]])
x_view = pypto.reshape(
cast_x, [x.shape[NUM_0], x.shape[NUM_1], x.shape[NUM_2] // NUM_2, NUM_2]
)
pypto.set_vec_tile_shapes(NUM_1, NUM_32, NUM_128, NUM_128)
x_t = pypto.transpose(x_view, NUM_2, NUM_3)
x_back = pypto.reshape(x_t, x.shape)
pypto.set_vec_tile_shapes(NUM_1, NUM_32, NUM_128, NUM_128)
x_embed = x_back * cast_cos + rotate_half(x_back) * cast_sin
return pypto.cast(x_embed, x.dtype)
def pre_compute_2d(
x_bs: pypto.Tensor, tens: MlaTensors, quant_inputs: MlaQuantInputs, eps_cq: float
) -> Tuple[pypto.Tensor, pypto.Tensor, pypto.Tensor]:
w_dq = tens.w_dq
w_uq_qr = tens.w_uq_qr
w_dkv_kr = tens.w_dkv_kr
gamma_cq = tens.gamma_cq
dq_w_scale = quant_inputs.dequant_scale_w_dq
dkv_w_scale = quant_inputs.dequant_scale_w_dkv_kr
uq_w_scale = quant_inputs.dequant_scale_w_uq_qr
smooth_scales_cq = quant_inputs.smooth_scales_cq
is_smooth = smooth_scales_cq is not None
is_quant_a = (dq_w_scale is not None) and (dkv_w_scale is not None)
is_quant_b = uq_w_scale is not None
bs = x_bs.shape[SHAPE_DIM_0]
q_rank = w_dq.shape[SHAPE_DIM_1]
dtype = x_bs.dtype
dtype_a = pypto.DT_INT32 if is_quant_a else dtype
dtype_b = pypto.DT_INT32 if is_quant_b else dtype
pypto.set_semantic_label("pre_reshape")
c0 = NUM_16
m = (min(NUM_32, bs) + c0 - 1) // c0 * c0
mv = min(NUM_8, bs)
q_a_proj = pypto.tensor()
if is_quant_a:
pypto.set_vec_tile_shapes(mv, q_rank)
pypto.set_cube_tile_shapes([m, m], [NUM_256, NUM_256], [NUM_256, NUM_256])
pypto.set_semantic_label("Quant_x")
quant_res = quantize(x_bs)
x_q, x_q_scale = quant_res[0], quant_res[1]
pypto.set_semantic_label("QuantMatmul_qa")
q_a_proj[:] = pypto.matmul(x_q, w_dq, dtype_a)
pypto.set_semantic_label("Dequant_qa")
q_a_proj[:] = dequant(dtype, q_a_proj, x_q_scale, dq_w_scale)
else:
pypto.set_cube_tile_shapes([m, m], [NUM_256, NUM_256], [NUM_64, NUM_64])
pypto.set_semantic_label("Matmul_qa")
q_a_proj[:] = pypto.matmul(x_bs, w_dq, dtype)
pypto.set_vec_tile_shapes(mv, q_rank)
pypto.set_semantic_label("RmsNorm_qa")
q_rms = pypto.rms_norm(q_a_proj, gamma_cq, eps_cq)
q_b_proj = pypto.tensor()
if is_quant_b:
pypto.set_vec_tile_shapes(mv, q_rank)
pypto.set_cube_tile_shapes([m, m], [NUM_256, NUM_256], [NUM_256, NUM_256])
pypto.set_semantic_label("Quant_qMmRes")
if is_smooth:
quant_res = quantize(q_rms, True, True, smooth_scales_cq)
else:
quant_res = quantize(q_rms, True, False)
q_q, q_q_scale = quant_res[0], quant_res[1]
pypto.set_semantic_label("QuantMatmul_qb")
q_b_proj[:] = pypto.matmul(q_q, w_uq_qr, dtype_b)
pypto.set_semantic_label("Dequant_qb")
q_b_proj[:] = dequant(dtype, q_b_proj, q_q_scale, uq_w_scale)
else:
pypto.set_cube_tile_shapes([m, m], [NUM_256, NUM_256], [NUM_64, NUM_64])
pypto.set_semantic_label("Matmul_qb")
q_b_proj[:] = pypto.matmul(q_rms, w_uq_qr, dtype)
compressed_kv = pypto.tensor()
if is_quant_a:
pypto.set_vec_tile_shapes(mv, q_rank)
pypto.set_cube_tile_shapes([m, m], [NUM_256, NUM_256], [NUM_256, NUM_256])
pypto.set_semantic_label("QuantMatmul_kva")
compressed_kv[:] = pypto.matmul(x_q, w_dkv_kr, dtype_a)
pypto.set_semantic_label("Dequant_kva")
compressed_kv[:] = dequant(dtype, compressed_kv, x_q_scale, dkv_w_scale)
else:
pypto.set_cube_tile_shapes([m, m], [NUM_256, NUM_256], [NUM_64, NUM_64])
pypto.set_semantic_label("Matmul_kva")
compressed_kv[:] = pypto.matmul(x_bs, w_dkv_kr, dtype)
return q_b_proj, compressed_kv, q_rms
def mla_prolog_compute(args: MlaArgs):
p = args.params
t = args.tensors
quant_inputs = args.quant
tiles = p.tiles
assert (
len(t.x.shape) == NUM_3
and len(t.w_uk.shape) == NUM_3
and len(t.sin.shape) == NUM_3
)
assert len(t.kv_cache.shape) == NUM_4 and len(t.kr_cache.shape) == NUM_4
assert p.cache_mode in ["PA_BSND", "PA_NZ"]
dtype = t.x.dtype
h = t.x.shape[SHAPE_DIM_2]
n = t.w_uk.shape[SHAPE_DIM_0]
q_lora_rank = t.w_dq.shape[SHAPE_DIM_1]
qk_nope_head_dim = t.w_uk.shape[SHAPE_DIM_1]
kv_lora_rank = t.w_uk.shape[SHAPE_DIM_2]
qk_rope_head_dim = t.sin.shape[SHAPE_DIM_2]
q_head_dim = qk_nope_head_dim + qk_rope_head_dim
block_num = t.kv_cache.shape[SHAPE_DIM_0]
block_size = t.kv_cache.shape[SHAPE_DIM_1]
n2 = t.kv_cache.shape[SHAPE_DIM_2]
assert qk_nope_head_dim == NUM_128 or qk_rope_head_dim == NUM_64
tile_b = tiles.tile_b
tile_s = tiles.tile_s
tile_bs = tile_b * tile_s
rope_cfg = p.tiles.rope
b = t.x.shape[SHAPE_DIM_0]
s = t.x.shape[SHAPE_DIM_1]
bs_loop = (b * s + tile_bs - 1) // tile_bs
x_2d = pypto.tensor([b * s, h], t.x.dtype, "x2D")
cos_2d = pypto.tensor([b * s, qk_rope_head_dim], t.cos.dtype, "cos2D")
sin_2d = pypto.tensor([b * s, qk_rope_head_dim], t.sin.dtype, "sin2D")
k_cache_index_2d = pypto.tensor([b * s, 1], t.cache_index.dtype, "kCacheIndex2D")
for _ in pypto.loop(0, 1, 1, name="LOOP_MLA_RESHAPE", idx_name="batch_id"):
x_2d[:] = pypto.reshape(t.x, [b * s, h], inplace=True)
cos_2d[:] = pypto.reshape(t.cos, [b * s, qk_rope_head_dim], inplace=True)
sin_2d[:] = pypto.reshape(t.sin, [b * s, qk_rope_head_dim], inplace=True)
k_cache_index_2d[:] = pypto.reshape(t.cache_index, [b * s, 1], inplace=True)
kv_cache_res = pypto.tensor(
[block_num * block_size * n2, kv_lora_rank], t.kv_cache.dtype, "kvCacheRes"
)
kr_cache_res = pypto.tensor(
[block_num * block_size * n2, qk_rope_head_dim], t.kr_cache.dtype, "krCacheRes"
)
for _ in pypto.loop(0, 1, 1, name="MLA_RESHAPE", idx_name="unused_idx"):
kv_cache_res[:] = pypto.reshape(
t.kv_cache, [block_num * block_size * n2, kv_lora_rank], inplace=True
)
kr_cache_res[:] = pypto.reshape(
t.kr_cache, [block_num * block_size * n2, qk_rope_head_dim], inplace=True
)
for bs_idx in pypto.loop(0, bs_loop, 1, name="MLA_BS_Loop", idx_name="bs_idx"):
bs_offset = bs_idx * tile_bs
output_offset = [bs_offset, 0, 0]
pypto.set_vec_tile_shapes(tile_bs, NUM_128)
x_view = pypto.view(x_2d, [tile_bs, h], [bs_offset, 0])
x_view[:] = pypto.cast(pypto.cast(x_view, pypto.DT_FP32), dtype)
q, kv_tmp, q_rms = pre_compute_2d(
x_view,
t,
quant_inputs,
p.eps_cq,
)
q_tmp = pypto.reshape(q, [tile_bs, n, q_head_dim])
pypto.set_semantic_label("Prepare_qNope")
q_nope = pypto.view(q_tmp, [tile_bs, n, qk_nope_head_dim], [0, 0, 0])
tile_shape = [min(32, tile_bs), 1, qk_nope_head_dim]
pypto.set_vec_tile_shapes(*tile_shape)
q_nope_trans = pypto.transpose(q_nope, 0, 1)
pypto.set_semantic_label("pre_reshape")
c0 = NUM_16
m = (min(NUM_32, tile_bs) + c0 - 1) // c0 * c0
pypto.set_semantic_label("Matmul_qNope_wUk")
pypto.set_cube_tile_shapes([m, m], [NUM_128, NUM_128], [NUM_128, NUM_128])
q_nope_new = pypto.matmul(q_nope_trans, t.w_uk, dtype)
pypto.set_semantic_label("queryOut")
tile_shape = [NUM_1, min(NUM_32, tile_bs), kv_lora_rank]
pypto.set_vec_tile_shapes(*tile_shape)
q_nope_new_trans = pypto.transpose(q_nope_new, 0, 1)
pypto.set_semantic_label("Assemble_queryOut")
pypto.set_vec_tile_shapes(NUM_1, NUM_32, NUM_128)
pypto.assemble(q_nope_new_trans, output_offset, t.q_out)
q_pe_view = pypto.view(
q_tmp, [tile_bs, n, qk_rope_head_dim], [0, 0, qk_nope_head_dim]
)
cos_2d[:] = pypto.view(cos_2d, [tile_bs, qk_rope_head_dim], [bs_offset, 0])
sin_2d[:] = pypto.view(sin_2d, [tile_bs, qk_rope_head_dim], [bs_offset, 0])
q_rope_view = rope_3d(q_pe_view, cos_2d, sin_2d, rope_cfg)
pypto.set_semantic_label("Assemble_qRope")
pypto.set_vec_tile_shapes(NUM_1, NUM_32, NUM_64)
pypto.assemble(q_rope_view, output_offset, t.q_rope_out)
pypto.set_vec_tile_shapes(NUM_2, NUM_512)
pypto.set_semantic_label("RotaryPosEmb")
k_pe_view = pypto.view(
kv_tmp, [tile_bs, qk_rope_head_dim], [0, kv_lora_rank]
)
k_rope_view = rope_2d(k_pe_view, cos_2d, sin_2d, rope_cfg)
k_rope_res = pypto.reshape(k_rope_view, [tile_bs, 1, 1, qk_rope_head_dim])
pypto.set_semantic_label("ScatterUpdate_krCache")
tile_shape = [NUM_1, qk_rope_head_dim]
pypto.set_vec_tile_shapes(*tile_shape)
index = pypto.view(k_cache_index_2d, [tile_bs, 1], [bs_offset, 0])
pypto.set_vec_tile_shapes(NUM_4, NUM_128, NUM_128, NUM_128)
t.kr_cache_out[:] = pypto.scatter_update(t.kr_cache, -2, index, k_rope_res)
compressed_kv = pypto.view(kv_tmp, [tile_bs, kv_lora_rank], [0, 0])
tile_shape = [NUM_2, NUM_512]
pypto.set_semantic_label("RmsNorm_compressedKv")
pypto.set_vec_tile_shapes(*tile_shape)
k_nope = pypto.rms_norm(compressed_kv, t.gamma_ckv, p.eps_ckv)
k_nope[:] = pypto.reshape(k_nope, [tile_bs, 1, 1, kv_lora_rank])
pypto.set_semantic_label("ScatterUpdate_kvCache")
pypto.set_vec_tile_shapes(NUM_4, NUM_128, NUM_128, NUM_512)
t.kv_cache_out[:] = pypto.scatter_update(t.kv_cache, -2, index, k_nope)
pypto.set_vec_tile_shapes(tile_bs, q_lora_rank)
rms_3d = pypto.cast(pypto.cast(q_rms, pypto.DT_FP32), dtype)
pypto.assemble(rms_3d, [bs_offset, 0], t.rms_res)
def mla_prolog(args: MlaArgs):
inp = [
args.tensors.x,
args.tensors.w_dq,
args.tensors.w_uq_qr,
args.tensors.w_uk,
args.tensors.w_dkv_kr,
args.tensors.gamma_cq,
args.tensors.gamma_ckv,
args.tensors.sin,
args.tensors.cos,
args.tensors.cache_index,
args.tensors.kv_cache,
args.tensors.kr_cache,
]
out = [
args.tensors.q_out,
args.tensors.q_rope_out,
args.tensors.kv_cache_out,
args.tensors.kr_cache_out,
args.tensors.rms_res,
]
with pypto.function("MLAProlog", *inp, *out):
mla_prolog_compute(args)
@dataclass
class MlaBuildConfig:
b: int = NUM_4
s1: int = NUM_2
n1: int = NUM_128
n2: int = NUM_1
h: int = NUM_7168
q_lora_rank: int = NUM_1536
kv_lora_rank: int = NUM_512
qk_rope_head_dim: int = NUM_64
qk_nope_head_dim: int = NUM_128
rope_dim: int = NUM_64
cache_mode: str = "PA_BSND"
block_size: int = NUM_128
block_num: int = NUM_1127
eps_cq: float = 1e-5
eps_ckv: float = 1e-5
tile_b_override: int = NUM_NEG1
tile_s: int = NUM_1
rope_2d: List[int] = field(default_factory=lambda: [NUM_128, NUM_128])
rope_3d_vals: List[int] = field(default_factory=lambda: [NUM_32, NUM_128, NUM_128])
rope_4d: List[int] = field(
default_factory=lambda: [NUM_16, NUM_128, NUM_128, NUM_128]
)
def setup_codegen_passes():
pypto.set_pass_options(
cube_l1_reuse_setting={-1: NUM_4},
cube_nbuffer_setting={NUM_3: NUM_4},
)
def build_args(cfg: MlaBuildConfig):
tile_b = (
cfg.tile_b_override
if cfg.tile_b_override != NUM_NEG1
else (NUM_8 if cfg.b == NUM_24 else cfg.b)
)
tiles = MlaTileConfig(
tile_b=tile_b,
tile_s=cfg.tile_s,
rope=RopeTileShapeConfig(
rope_2d=cfg.rope_2d,
rope_3d_vals=cfg.rope_3d_vals,
rope_4d=cfg.rope_4d,
),
)
params = MlaParams(
b=cfg.b,
s1=cfg.s1,
n1=cfg.n1,
n2=cfg.n2,
h=cfg.h,
q_lora_rank=cfg.q_lora_rank,
kv_lora_rank=cfg.kv_lora_rank,
qk_rope_head_dim=cfg.qk_rope_head_dim,
qk_nope_head_dim=cfg.qk_nope_head_dim,
rope_dim=cfg.rope_dim,
cache_mode=cfg.cache_mode,
block_size=cfg.block_size,
block_num=cfg.block_num,
eps_cq=cfg.eps_cq,
eps_ckv=cfg.eps_ckv,
tiles=tiles,
)
d_fp16 = pypto.DT_FP16
d_int32 = pypto.DT_INT32
x = pypto.tensor([cfg.b, cfg.s1, cfg.h], d_fp16, "x")
w_dq = pypto.tensor(
[cfg.h, cfg.q_lora_rank], d_fp16, "wDq", pypto.TileOpFormat.TILEOP_ND
)
w_uq_qr = pypto.tensor(
[cfg.q_lora_rank, cfg.n1 * (cfg.qk_nope_head_dim + cfg.qk_rope_head_dim)],
d_fp16,
"wUqQr",
pypto.TileOpFormat.TILEOP_ND,
)
w_dkv_kr = pypto.tensor(
[cfg.h, cfg.kv_lora_rank + cfg.qk_rope_head_dim],
d_fp16,
"wDkvKr",
pypto.TileOpFormat.TILEOP_ND,
)
w_uk = pypto.tensor(
[cfg.n1, cfg.qk_nope_head_dim, cfg.kv_lora_rank],
d_fp16,
"wUk",
pypto.TileOpFormat.TILEOP_ND,
)
gamma_cq = pypto.tensor(
[cfg.q_lora_rank], d_fp16, "gammaCq", pypto.TileOpFormat.TILEOP_ND
)
gamma_ckv = pypto.tensor(
[cfg.kv_lora_rank], d_fp16, "gammaCkv", pypto.TileOpFormat.TILEOP_ND
)
cos = pypto.tensor([cfg.b, cfg.s1, cfg.qk_rope_head_dim], d_fp16, "cos")
sin = pypto.tensor([cfg.b, cfg.s1, cfg.qk_rope_head_dim], d_fp16, "sin")
cache_index = pypto.tensor([cfg.b, cfg.s1], d_int32, "cacheIndex")
kv_cache = pypto.tensor(
[cfg.block_num, cfg.block_size, cfg.n2, cfg.kv_lora_rank], d_fp16, "kvCache"
)
kr_cache = pypto.tensor(
[cfg.block_num, cfg.block_size, cfg.n2, cfg.qk_rope_head_dim], d_fp16, "krCache"
)
q_out = pypto.tensor(
[cfg.b * cfg.s1, cfg.n1, cfg.kv_lora_rank], d_fp16, "queryNopeOut"
)
q_rope_out = pypto.tensor(
[cfg.b * cfg.s1, cfg.n1, cfg.qk_rope_head_dim], d_fp16, "queryRopeOut"
)
kv_cache_out = pypto.tensor(
[cfg.block_num, cfg.block_size, cfg.n2, cfg.kv_lora_rank], d_fp16, "kvCacheOut"
)
kr_cache_out = pypto.tensor(
[cfg.block_num, cfg.block_size, cfg.n2, cfg.qk_rope_head_dim],
d_fp16,
"krCacheOut",
)
rms_res = pypto.tensor(
[cfg.b * cfg.s1, cfg.q_lora_rank],
d_fp16,
"rmsRes",
)
tensors = MlaTensors(
x=x,
w_dq=w_dq,
w_uq_qr=w_uq_qr,
w_uk=w_uk,
w_dkv_kr=w_dkv_kr,
gamma_cq=gamma_cq,
gamma_ckv=gamma_ckv,
sin=sin,
cos=cos,
cache_index=cache_index,
kv_cache=kv_cache,
kr_cache=kr_cache,
q_out=q_out,
q_rope_out=q_rope_out,
kv_cache_out=kv_cache_out,
kr_cache_out=kr_cache_out,
rms_res=rms_res,
)
return MlaArgs(params=params, tensors=tensors, quant=MlaQuantInputs())
@duration_estimate(12)
def test_dynamic_mla_prolog():
args = build_args(MlaBuildConfig())
setup_codegen_passes()
mla_prolog(args)
assert True
if __name__ == "__main__":
test_dynamic_mla_prolog()