'''
Golden-reference generation and NPU test cases for the mla_prolog_quant prefill/decode kernels.
'''
import os
import math
import time
import logging
from pathlib import Path
import torch
import torch_npu
import pytest
import pypto
from mla_prolog_quant_impl import mla_prolog_quant_p, mla_prolog_quant_d, MlaTileConfig
from utils.compare import compare
def prep_env():
    """Select the NPU device for this process and allow internal tensor formats.

    The device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable; it defaults to 0 when the variable is not set.
    """
    npu_index = int(os.getenv('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(npu_index)
    # Let torch_npu use its internal (non-ND) tensor formats.
    torch_npu.npu.config.allow_internal_format = True
def rms_norm(x, gamma):
    """Normalise ``x`` by its root-mean-square over the last axis and scale by ``gamma``.

    The arithmetic is carried out in float32. No epsilon is added to the
    denominator, so a row of zeros divides by zero.

    Args:
        x: Input tensor; the last dimension is the one being normalised.
        gamma: Scale multiplied onto the normalised values (broadcast by torch).

    Returns:
        The normalised tensor, cast back to the dtype of ``x``.
    """
    orig_dtype = x.dtype
    inv_width = 1.0 / x.shape[-1]
    x32 = x.to(torch.float32)
    # Each squared element is scaled by 1/width before summing, which yields
    # the mean of squares along the last axis.
    mean_square = torch.sum(x32 * x32 * inv_width, dim=-1, keepdims=True)
    normed = (x32 / torch.sqrt(mean_square)) * gamma
    if orig_dtype == torch.float32:
        return normed
    return normed.to(orig_dtype)
def scatter_update(inputs, axis):
    """Write rows of ``key_states`` into a flattened view of ``cache``.

    Args:
        inputs: Sequence ``(cache, key_states, indices)``. ``cache`` is 4-D
            ``(block_number, block_size, n2, d)``; ``indices`` is 2-D
            ``(b, s1)`` and holds row numbers into the cache flattened to
            ``(block_number * block_size * n2, d)``. Row ``i * s1 + j`` of
            ``key_states`` is written to the row named by ``indices[i][j]``.
        axis: Rows are only written when this equals ``-2``; for any other
            value the cache contents are returned untouched.

    Returns:
        The cache reshaped back to its original 4-D shape. When ``reshape``
        returns a view, the tensor passed in as ``cache`` is modified in place.
    """
    cache, key_states, indices = inputs
    block_number, block_size, n2, d = cache.shape
    flat_cache = cache.reshape(block_number * block_size * n2, d)
    b, s1 = indices.shape
    if axis == -2:
        for src_row in range(b * s1):
            dst_row = indices[src_row // s1][src_row % s1]
            flat_cache[dst_row][:] = key_states[src_row][:]
    return flat_cache.reshape(block_number, block_size, n2, d)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension ``[a, b]`` (split at ``len // 2``) the result is
    ``[-b, a]``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb_v2(q, k, cos, sin, unsqueeze_dim=2):
    """Apply rotary position embedding to ``q`` and ``k``.

    Both inputs are 4-D ``(b, s, h, d)``. Their last dimension is first
    rearranged from interleaved pairs ``(x0, y0, x1, y1, ...)`` to
    ``(x0, x1, ..., y0, y1, ...)``, then combined with ``cos``/``sin`` using
    ``rotate_half``. The arithmetic runs in float32.

    Args:
        q: Query tensor, shape ``(b, s, h, d)``.
        k: Key tensor, shape ``(b, s, h, d)`` (its own b/s/h/d).
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``.
        sin: Sine table; treated the same way as ``cos``.
        unsqueeze_dim: Axis at which ``cos``/``sin`` are unsqueezed.

    Returns:
        ``(q_embed, k_embed)`` cast back to the original dtype of ``q``.
    """
    input_dtype = q.dtype
    cast_back = input_dtype != torch.float32
    if cast_back:
        q, k = q.to(torch.float32), k.to(torch.float32)
    if cos.dtype != torch.float32:
        cos, sin = cos.to(torch.float32), sin.to(torch.float32)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _deinterleave(t):
        # (.., d/2, 2) -> (.., 2, d/2): gather even positions first, then odd.
        b, s, h, d = t.shape
        return t.reshape(b, s, h, d // 2, 2).permute(0, 1, 2, 4, 3).reshape(b, s, h, d)

    q = _deinterleave(q)
    k = _deinterleave(k)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    if cast_back:
        q_embed = q_embed.to(input_dtype)
        k_embed = k_embed.to(input_dtype)
    return q_embed, k_embed
def quant(input_t, is_pertoken: bool = True, has_smooth=False, smooth_cq=None):
input_fp32 = input_t.to(torch.float32)
if has_smooth:
input_fp32 = input_fp32 * smooth_cq
abs_res = torch.abs(input_fp32)
reduce_idx = -1
if not is_pertoken:
reduce_idx = -2
logging.debug("This PerChannel Quant!!")
max_value = torch.max(abs_res, dim=reduce_idx, keepdims=True)[0]
scale_quant = 127 / max_value
out_fp32 = input_fp32 * scale_quant
out_int32 = torch.round(out_fp32).to(torch.int32)
out_fp16 = out_int32.to(torch.float16)
out_int8 = torch.trunc(out_fp16).to(torch.int8)
scale_dequant = 1 / scale_quant
return out_int8, scale_dequant
def tensor_to_file(t: torch.Tensor, output: Path):
    """Dump the raw bytes of ``t`` to ``output``, one leading-axis slice at a time.

    bfloat16 tensors are reinterpreted as int16 before the numpy conversion,
    so the file holds their bit patterns unchanged.

    Args:
        t: Tensor to write; it is iterated along its first dimension.
        output: Destination path; the file is created or overwritten.
    """
    with open(str(output), "wb") as sink:
        raw_dtype = torch.int16 if t.dtype == torch.bfloat16 else t.dtype
        for chunk in t:
            sink.write(chunk.view(raw_dtype).cpu().numpy().tobytes())
def mla_prolog_quant_v32_compute(inputs):
    """Compute the reference (golden) outputs of the MLA prolog.

    Args:
        inputs: Dict with the keys ``dtype``, ``is_quant_a``, ``is_quant_b``,
            ``has_smooth``, ``gamma_cq``, ``gamma_ckv``, ``x``, ``w_dq``,
            ``w_uqqr``, ``w_uk``, ``w_dkvkr``, ``cos``, ``sin``, ``kv_cache``,
            ``kr_cache``, ``cache_index`` and, depending on the flags,
            ``kv_quant_scale_cache``, ``w_qa_scale``, ``w_kva_scale``,
            ``w_qb_scale`` and ``smooth_cq``. ``x`` is unpacked as
            ``(b, s, h)`` and ``w_uk`` as
            ``(n, qk_nope_head_dim, kv_lora_rank)``.

    Returns:
        Tuple ``(q_out, q_embed, q_a_layernorm, q_a_layernorm_scale_dequant,
        kv_cache_out, kr_cache_out, kv_quant_scale_cache_out)``. The scale
        entries are ``None`` when ``is_quant_b`` is false. The caches passed
        in are cloned before being updated.
    """
    dtype = inputs.get("dtype")
    is_quant_a = inputs.get("is_quant_a")
    is_quant_b = inputs.get("is_quant_b")
    has_smooth = inputs.get("has_smooth")
    gamma_cq = inputs.get("gamma_cq")
    gamma_ckv = inputs.get("gamma_ckv")
    x = inputs.get("x")
    w_dq = inputs.get("w_dq")
    w_uqqr = inputs.get("w_uqqr")
    w_uk = inputs.get("w_uk")
    w_dkvkr = inputs.get("w_dkvkr")
    cos = inputs.get("cos")
    sin = inputs.get("sin")
    kv_cache = inputs.get("kv_cache")
    kr_cache = inputs.get("kr_cache")
    kv_quant_scale_cache = None
    if is_quant_b:
        kv_quant_scale_cache = inputs.get("kv_quant_scale_cache")
    cache_index = inputs.get("cache_index")
    if is_quant_a:
        w_qa_scale = inputs.get("w_qa_scale")
        w_kva_scale = inputs.get("w_kva_scale")
    if is_quant_b:
        w_qb_scale = inputs.get("w_qb_scale")
        if has_smooth:
            smooth_cq = inputs.get("smooth_cq")

    b, s, h = x.shape
    qk_rope_head_dim = cos.shape[2]
    n, qk_nope_head_dim, kv_lora_rank = w_uk.shape
    q_head_dim = qk_nope_head_dim + qk_rope_head_dim

    # ---- q path ----
    x_2d = x.reshape(b * s, h)
    if is_quant_a:
        # Quantise x once; the kv projection below reuses the same result.
        x_2d_quant, x_2d_scale_dequant = quant(x_2d, True)
        q_a_proj = torch.matmul(x_2d_quant.to(torch.int32), w_dq.to(torch.int32))
        # Dequantise: activation scale first, then weight scale.
        q_a_proj = q_a_proj.to(torch.float32) * x_2d_scale_dequant * w_qa_scale
    else:
        q_a_proj = torch.matmul(x_2d.to(torch.float32), w_dq.to(torch.float32))
    # Cast to the requested activation dtype, like every other projection here
    # (this was previously hard-coded to bfloat16).
    q_a_proj = q_a_proj.to(dtype)

    q_a_layernorm = rms_norm(q_a_proj, gamma_cq)
    q_a_layernorm_scale_dequant = None
    if is_quant_b:
        if has_smooth:
            q_a_layernorm, q_a_layernorm_scale_dequant = quant(q_a_layernorm, True, True, smooth_cq)
        else:
            q_a_layernorm, q_a_layernorm_scale_dequant = quant(q_a_layernorm, True)
        # The int32 matmul is evaluated on the CPU and moved back afterwards.
        q_b_proj = torch.matmul(q_a_layernorm.to(torch.int32).cpu(),
                                w_uqqr.to(torch.int32).cpu()).to(q_a_layernorm.device)
        q_b_proj = q_b_proj.to(torch.float32) * q_a_layernorm_scale_dequant * w_qb_scale
    else:
        q_b_proj = torch.matmul(q_a_layernorm.to(torch.float32), w_uqqr.to(torch.float32))
    q_b_proj = q_b_proj.to(dtype)

    q_reshape = q_b_proj.reshape(b, s, n, q_head_dim)
    q_nope = q_reshape[:, :, :, 0:qk_nope_head_dim]
    # Per-head matmul with w_uk: (n, b*s, nope) @ (n, nope, kv_lora_rank).
    q_nope_t = q_nope.reshape(b * s, n, qk_nope_head_dim).permute(1, 0, 2)
    q_nope_new = torch.matmul(q_nope_t.to(torch.float32), w_uk.to(torch.float32))
    q_nope_new = q_nope_new.to(dtype)
    q_out = q_nope_new.permute(1, 0, 2).reshape(b, s, n, kv_lora_rank)

    # ---- kv path ----
    if is_quant_a:
        kv_a_proj = torch.matmul(x_2d_quant.to(torch.int32), w_dkvkr.to(torch.int32))
        kv_a_proj = kv_a_proj.to(torch.float32) * x_2d_scale_dequant * w_kva_scale
    else:
        kv_a_proj = torch.matmul(x_2d.to(torch.float32), w_dkvkr.to(torch.float32))
    kv_a_proj = kv_a_proj.to(dtype)

    kv_reshape = kv_a_proj.reshape(b, s, kv_lora_rank + qk_rope_head_dim)
    compressed_kv = kv_reshape[:, :, 0:kv_lora_rank]
    compressed_kv_norm = rms_norm(compressed_kv, gamma_ckv)
    compressed_kv_quant_scale = None
    if is_quant_b:
        # Each token's kv_lora_rank values are quantised in 4 equal groups,
        # giving 4 scales per token.
        compressed_kv_norm_split = compressed_kv_norm.reshape(b * s, 4, kv_lora_rank // 4)
        compressed_kv_norm, compressed_kv_quant_scale = quant(compressed_kv_norm_split, True)
        compressed_kv_quant_scale = compressed_kv_quant_scale.reshape(b, s, 1, 4)
    compressed_kv_r = compressed_kv_norm.reshape(b, s, 1, kv_lora_rank)
    k_nope = compressed_kv_r.reshape(b * s * 1, kv_lora_rank)

    # ---- RoPE ----
    q_pe = q_reshape[:, :, :, qk_nope_head_dim:]
    k_pe = kv_reshape[:, :, kv_lora_rank:]
    k_pe_r = k_pe.reshape(b, s, 1, qk_rope_head_dim)
    q_embed, k_embed = apply_rotary_pos_emb_v2(q_pe, k_pe_r, cos, sin, 2)
    k_embed_r = k_embed.reshape(b * 1 * s, qk_rope_head_dim)

    # ---- cache updates (performed on clones of the incoming caches) ----
    kv_cache_out = scatter_update([kv_cache.clone(), k_nope, cache_index], -2)
    kr_cache_out = scatter_update([kr_cache.clone(), k_embed_r, cache_index], -2)
    if is_quant_b:
        compressed_kv_quant_scale = compressed_kv_quant_scale.reshape(-1, 4)
        kv_quant_scale_cache_out = scatter_update(
            [kv_quant_scale_cache.clone(), compressed_kv_quant_scale, cache_index], -2)
    else:
        kv_quant_scale_cache_out = None

    return q_out, q_embed, q_a_layernorm, q_a_layernorm_scale_dequant, kv_cache_out, \
        kr_cache_out, kv_quant_scale_cache_out
def gen_block_table(act_seq, block_size, s1, need_indices=False):
    """Build a randomly permuted block table and, optionally, cache write indices.

    Args:
        act_seq: Per-batch sequence lengths; iterated along its first axis.
        block_size: Number of tokens per cache block.
        s1: Number of trailing tokens per batch for which an index is made.
        need_indices: When true, also compute ``cache_index``.

    Returns:
        ``(block_num, block_table, cache_index)``. ``block_table`` is int32 of
        shape ``(b, ceil(max(act_seq) / block_size))`` with unused slots set to
        -1. ``cache_index`` is int64 of shape ``(b, s1)`` holding
        ``block_id * block_size + offset`` for the last ``s1`` positions of
        each sequence, or ``None`` when ``need_indices`` is false.
    """
    b = act_seq.shape[0]
    max_kv = max(act_seq)
    blocks_per_seq = [math.ceil(seq_len / block_size) for seq_len in act_seq]
    block_num = sum(blocks_per_seq)

    # Hand out the block ids 0..block_num-1 in random order.
    shuffled_ids = torch.arange(0, block_num, 1)
    shuffled_ids = shuffled_ids[torch.randperm(shuffled_ids.size(0))].to(torch.int32)

    block_table = -torch.ones([b, math.ceil(max_kv / block_size)], dtype=torch.int32)
    cursor = 0
    for row, count in enumerate(blocks_per_seq):
        block_table[row, :count] = shuffled_ids[cursor:cursor + count]
        cursor += count

    cache_index = None
    if need_indices:
        cache_index = -torch.ones((b, s1), dtype=torch.int64)
        for seq_idx in range(b):
            seq_len = act_seq[seq_idx]
            for tok in range(s1):
                # Position of this token counted from the start of the sequence.
                pos = seq_len - s1 + tok
                block_id = block_table[seq_idx, pos // block_size]
                cache_index[seq_idx, tok] = block_id * block_size + pos % block_size
    return block_num, block_table, cache_index
def gen_mla_prolog_quant_v32_input_data(params, dtypes, actual_seq, is_quant=(False, False),
                                        has_smooth=False, block_size=128, cache_mode="BSND"):
    """Generate random inputs and paged caches for the MLA prolog tests.

    Args:
        params: Dict providing ``b``, ``s``, ``s1``, ``h``, ``n1``,
            ``q_lora_rank``, ``qk_nope_head_dim``, ``qk_rope_head_dim`` and
            ``kv_lora_rank``.
        dtypes: ``(dtype, w_dtype)`` torch dtypes for activations and weights.
        actual_seq: Per-batch kv sequence lengths.
        is_quant: ``(is_quant_a, is_quant_b)`` flags selecting which weights
            are int8-quantised.
        has_smooth: Whether to generate ``smooth_cq`` (only done when
            ``is_quant_b`` is set).
        block_size: Tokens per cache block.
        cache_mode: Accepted for interface compatibility; not read here.

    Returns:
        A 17-element list: ``[x, w_dq, w_uqqr, smooth_cq, scale_data, w_dkvkr,
        w_uk, gamma_cq, gamma_ckv, cos, sin, cache_index, kv_cache, kr_cache,
        kv_quant_scale_cache, block_num, block_table]``. ``scale_data`` is a
        dict of weight dequant scales keyed by weight name; entries that do
        not apply are ``None``.
    """
    dtype, w_dtype = dtypes
    is_quant_a, is_quant_b = is_quant
    b = params.get("b")
    s = params.get("s")
    s1 = params.get("s1")
    h = params.get("h")
    n = params.get("n1")
    q_lora_rank = params.get("q_lora_rank")
    qk_nope_head_dim = params.get("qk_nope_head_dim")
    qk_rope_head_dim = params.get("qk_rope_head_dim")
    kv_lora_rank = params.get("kv_lora_rank")

    def _uniform(shape, bound, out_dtype):
        # Draw from U(-bound, bound) in the default dtype, then cast.
        return torch.empty(shape).uniform_(-bound, bound).to(out_dtype)

    # NOTE: the order of the random draws below fixes the data produced for a
    # given seed, so it must not be rearranged.
    block_num, block_table, cache_index = gen_block_table(actual_seq, block_size, s1, need_indices=True)
    skv_max = actual_seq.max()
    q_head_dim = qk_nope_head_dim + qk_rope_head_dim
    kv_width = kv_lora_rank + qk_rope_head_dim

    x = _uniform([b, s, h], 1, dtype)
    w_dq = _uniform([h, q_lora_rank], 0.1, w_dtype)
    w_uqqr = _uniform([q_lora_rank, n * q_head_dim], 0.1, w_dtype)
    w_dkvkr = _uniform([h, kv_width], 0.1, w_dtype)

    scale_data = {}
    smooth_cq = None
    if is_quant_a:
        w_dq, w_qa_scale = quant(w_dq, False)
        w_dkvkr, w_kva_scale = quant(w_dkvkr, False)
        scale_data["w_dq"] = w_qa_scale
        scale_data["w_dkvkr"] = w_kva_scale
    if is_quant_b:
        w_uqqr, w_qb_scale = quant(w_uqqr, False)
        scale_data["w_uqqr"] = w_qb_scale
        if has_smooth:
            smooth_cq = _uniform([1, q_lora_rank], 1, torch.float32)

    w_uk = _uniform([n, qk_nope_head_dim, kv_lora_rank], 0.1, w_dtype)
    gamma_cq = _uniform([q_lora_rank], 1, dtype)
    gamma_ckv = _uniform([kv_lora_rank], 1, dtype)
    cos = _uniform([b, s, qk_rope_head_dim], 0.1, dtype)
    sin = _uniform([b, s, qk_rope_head_dim], 0.1, dtype)

    # Contiguous (BSND) kv data, zero-padded up to a whole number of blocks.
    k_bsnd = _uniform([b, skv_max, 1, kv_width], 1, dtype)
    per_batch_max_num = math.ceil(skv_max / block_size)
    k_tensor_bsnd = torch.zeros((b, per_batch_max_num * block_size, 1, kv_width)).to(dtype)
    k_tensor_bsnd[:, :k_bsnd.shape[1], :, :] = k_bsnd[:, :, :, :]

    # Copy each logical block into the physical block named by the block table.
    k_cache_tensor = torch.zeros([block_num, block_size, 1, kv_width]).to(dtype)
    for b_idx in range(b):
        for block_i, kv_cache_blk_id in enumerate(block_table[b_idx]):
            if kv_cache_blk_id == -1:
                continue
            start = block_i * block_size
            k_cache_tensor[kv_cache_blk_id, 0:block_size, :, :] = \
                k_tensor_bsnd[b_idx, start:(start + block_size), :, :]

    kv_cache = k_cache_tensor[:, :, :, :kv_lora_rank]
    kr_cache = k_cache_tensor[:, :, :, kv_lora_rank:]
    kv_quant_scale_cache = None
    if is_quant_b:
        # Quantise each cache row in 4 equal groups -> 4 scales per row.
        kv_cache, kv_quant_scale_cache = quant(kv_cache.reshape(-1, 4, kv_lora_rank // 4), True)
        kv_cache = kv_cache.reshape([block_num, block_size, 1, kv_lora_rank])
        kv_quant_scale_cache = kv_quant_scale_cache.reshape([block_num, block_size, 1, 4])

    return [
        x, w_dq, w_uqqr, smooth_cq, scale_data, w_dkvkr, w_uk,
        gamma_cq, gamma_ckv, cos, sin, cache_index,
        kv_cache, kr_cache, kv_quant_scale_cache, block_num, block_table,
    ]
def gen_mla_prolog_quant_v32_data(params, dtypes, actual_seq, is_quant=(False, False),
                                  has_smooth=False, block_size=128, cache_mode="BSND"):
    """Generate test inputs and the matching golden outputs.

    Inputs come from ``gen_mla_prolog_quant_v32_input_data``; tensors are
    moved to the NPU when one is available, and the golden results are then
    produced by ``mla_prolog_quant_v32_compute``.

    Returns:
        ``(inputs, outputs)``: the dict handed to the reference computation
        and a dict with the keys ``q_golden``, ``q_rope``, ``kr_golden``,
        ``kv_golden``, ``kv_quant_scale_cache_golden``, ``rms_norm_golden``
        and ``rms_norm_scale_golden``.
    """
    dtype, w_dtype = dtypes
    is_quant_a, is_quant_b = is_quant
    generated = gen_mla_prolog_quant_v32_input_data(
        params, dtypes, actual_seq, is_quant, has_smooth, block_size, cache_mode)
    (x, w_dq, w_uqqr, smooth_cq, scale_data, w_dkvkr, w_uk, gamma_cq, gamma_ckv, cos, sin,
     cache_index, kv_cache, kr_cache, kv_quant_scale_cache, block_num, block_table) = generated

    inputs = {
        "dtype": dtype,
        "is_quant_a": is_quant_a,
        "is_quant_b": is_quant_b,
        "has_smooth": has_smooth,
        "cache_mode": cache_mode,
        "gamma_cq": gamma_cq,
        "gamma_ckv": gamma_ckv,
        "x": x,
        "w_dq": w_dq,
        "w_uqqr": w_uqqr,
        "w_uk": w_uk,
        "w_dkvkr": w_dkvkr,
        "cos": cos,
        "sin": sin,
        "kv_cache": kv_cache,
        "kr_cache": kr_cache,
        "kv_quant_scale_cache": kv_quant_scale_cache,
        "cache_index": cache_index,
    }
    if is_quant_a:
        inputs["w_qa_scale"] = scale_data["w_dq"]
        inputs["w_kva_scale"] = scale_data["w_dkvkr"]
    if is_quant_b:
        inputs["w_qb_scale"] = scale_data["w_uqqr"]
        if has_smooth:
            inputs["smooth_cq"] = smooth_cq

    if torch_npu.npu.is_available():
        # Replace every tensor entry with its NPU copy; other values stay as-is.
        inputs.update({name: value.npu() for name, value in inputs.items()
                       if isinstance(value, torch.Tensor)})

    (q_out, q_embed, rms_norm_out, rms_norm_scale, kv_cache_out, kr_cache_out,
     kv_quant_scale_cache_out) = mla_prolog_quant_v32_compute(inputs)

    outputs = {
        "q_golden": q_out,
        "q_rope": q_embed,
        "kr_golden": kr_cache_out,
        "kv_golden": kv_cache_out,
        "kv_quant_scale_cache_golden": kv_quant_scale_cache_out,
        "rms_norm_golden": rms_norm_out,
        "rms_norm_scale_golden": rms_norm_scale,
    }
    return inputs, outputs
def convert_pypto_to_torch_type(pypto_type):
    """Map a pypto data type to the corresponding torch dtype.

    Supported: ``DT_INT8``, ``DT_INT32``, ``DT_FP32``, ``DT_FP16``, ``DT_BF16``.

    Raises:
        ValueError: If ``pypto_type`` is none of the supported types.
    """
    # Checked in order with ==, so no hashability is required of pypto types.
    type_pairs = (
        (pypto.DT_INT8, torch.int8),
        (pypto.DT_INT32, torch.int32),
        (pypto.DT_FP32, torch.float32),
        (pypto.DT_FP16, torch.float16),
        (pypto.DT_BF16, torch.bfloat16),
    )
    for candidate, torch_type in type_pairs:
        if pypto_type == candidate:
            return torch_type
    raise ValueError(f"Unsupported pypto.DataType: {pypto_type}")
def mla_prolog_quant_v32(params, input_tensors, golden_data, dtype, w_dtype, is_quant_a, \
                         is_quant_b, nz, tile_config, cache_mode, is_p):
    """Run the mla_prolog_quant kernel on the NPU and compare it with golden data.

    Args:
        params: Dict with ``b``, ``s``, ``s2``, ``n1``, ``h``, ``q_lora_rank``,
            ``qk_nope_head_dim``, ``qk_rope_head_dim``, ``kv_lora_rank`` and
            ``block_size``.
        input_tensors: Dict of input tensors. The entries ``w_dq``,
            ``w_dkvkr`` and ``w_uqqr`` are replaced in place by their
            FRACTAL_NZ-cast versions.
        golden_data: Dict of reference outputs to compare against.
        dtype: pypto activation data type.
        w_dtype: pypto weight data type.
        is_quant_a: Accepted for interface compatibility; not read here.
        is_quant_b: Whether the second q projection / kv cache are quantised;
            selects which outputs are compared and the tolerances used.
        nz: Accepted for interface compatibility; not read here.
        tile_config: Tile configuration forwarded to the kernel.
        cache_mode: Cache layout string forwarded to the kernel.
        is_p: Call ``mla_prolog_quant_p`` when true, otherwise
            ``mla_prolog_quant_d``.
    """
    # Not covered by the module-level import, hence imported locally.
    from mla_prolog_quant_impl import RopeTileShapeConfig

    d_type = pypto.DT_FP16 if dtype == pypto.DT_FP16 else pypto.DT_BF16
    # The q_norm output is int8 only when quant-b is on with int8 weights.
    if is_quant_b and w_dtype == pypto.DT_INT8:
        dtype_kv_quant = pypto.DT_INT8
    else:
        dtype_kv_quant = dtype

    b = params['b']
    s = params['s']
    t = b * s
    s2 = params['s2']
    n1 = params['n1']
    n2 = 1
    h = params['h']
    q_lora_rank = params['q_lora_rank']
    qk_nope_head_dim = params['qk_nope_head_dim']
    qk_rope_head_dim = params['qk_rope_head_dim']
    kv_lora_rank = params["kv_lora_rank"]
    block_size = params['block_size']
    q_head_dim = qk_nope_head_dim + qk_rope_head_dim

    token_x_shape = [t, h]
    w_dq_shape = [h, q_lora_rank]
    w_uq_qr_shape = [q_lora_rank, n1 * q_head_dim]
    dequant_scale_w_uq_qr_shape = [n1 * q_head_dim, 1]
    w_dkv_kr_shape = [h, kv_lora_rank + qk_rope_head_dim]
    w_uk_shape = [n1, qk_nope_head_dim, kv_lora_rank]
    rope_cos_shape = [t, qk_rope_head_dim]
    rmsnorm_gamma_cq_shape = [q_lora_rank]
    rmsnorm_gamma_ckv_shape = [kv_lora_rank]
    cache_index_shape = [t]
    block_num = b * ((s2 + block_size - 1) // block_size)
    kv_cache_shape = [block_num, block_size, n2, kv_lora_rank]
    kr_cache_shape = [block_num, block_size, n2, qk_rope_head_dim]
    k_scale_cache_out_shape = [block_num, block_size, n2, 4]
    q_nope_out_shape = [t, n1, kv_lora_rank]
    q_rope_out_shape = [t, n1, qk_rope_head_dim]
    q_norm_out_shape = [t, q_lora_rank]
    q_norm_scale_out_shape = [t, 1]

    # Golden tensors, reshaped to the layouts of the kernel outputs.
    golden1 = golden_data["q_golden"].reshape(q_nope_out_shape)
    golden2 = golden_data["q_rope"].reshape(q_rope_out_shape)
    golden3 = golden_data["kv_golden"].reshape(kv_cache_shape)
    golden4 = golden_data["kr_golden"].reshape(kr_cache_shape)
    golden6 = golden_data["rms_norm_golden"].reshape(q_norm_out_shape)
    if is_quant_b:
        golden5 = golden_data["kv_quant_scale_cache_golden"].reshape(k_scale_cache_out_shape)
        golden7 = golden_data["rms_norm_scale_golden"].reshape(q_norm_scale_out_shape)

    # Output buffers.
    output_q_norm_data = torch.empty(q_norm_out_shape, dtype=convert_pypto_to_torch_type(dtype_kv_quant)).npu()
    output_q_norm_scale_data = torch.empty(q_norm_scale_out_shape, dtype=torch.float32).npu()
    output_q_nope_data = torch.empty(q_nope_out_shape, dtype=convert_pypto_to_torch_type(d_type)).npu()
    output_q_rope_data = torch.empty(q_rope_out_shape, dtype=convert_pypto_to_torch_type(d_type)).npu()
    output_kv_cache_data = input_tensors["kv_cache"].reshape(kv_cache_shape).npu()
    output_kr_cache_data = input_tensors["kr_cache"].reshape(kr_cache_shape).npu()

    # Cast the three projection weights to FRACTAL_NZ and store them back.
    w_dq_nz = torch_npu.npu_format_cast(input_tensors["w_dq"].reshape(w_dq_shape).npu().contiguous(),
                                        torch_npu.Format.FRACTAL_NZ)
    w_dkvkr_nz = torch_npu.npu_format_cast(input_tensors["w_dkvkr"].reshape(w_dkv_kr_shape).npu().contiguous(),
                                           torch_npu.Format.FRACTAL_NZ)
    w_uqqr_nz = torch_npu.npu_format_cast(input_tensors["w_uqqr"].reshape(w_uq_qr_shape).npu().contiguous(),
                                          torch_npu.Format.FRACTAL_NZ)
    input_tensors["w_uqqr"] = w_uqqr_nz
    input_tensors["w_dkvkr"] = w_dkvkr_nz
    input_tensors["w_dq"] = w_dq_nz

    def _load(key, shape):
        # Fetch an input tensor, reshape it for the kernel and move it to the NPU.
        return input_tensors[key].reshape(shape).npu()

    token_x_data = _load("x", token_x_shape)
    w_dq_data = _load("w_dq", w_dq_shape)
    w_uq_qr_data = _load("w_uqqr", w_uq_qr_shape)
    w_uk_data = _load("w_uk", w_uk_shape)
    w_dkv_kr_data = _load("w_dkvkr", w_dkv_kr_shape)
    rmsnorm_gamma_cq_data = _load("gamma_cq", rmsnorm_gamma_cq_shape)
    rmsnorm_gamma_ckv_data = _load("gamma_ckv", rmsnorm_gamma_ckv_shape)
    rope_cos_data = _load("cos", rope_cos_shape)
    rope_sin_data = _load("sin", rope_cos_shape)
    cache_index_data = _load("cache_index", cache_index_shape)
    kv_cache_data = _load("kv_cache", kv_cache_shape)
    kr_cache_data = _load("kr_cache", kr_cache_shape)

    if is_quant_b:
        # The same tensor serves as the scale-cache input and output.
        k_scale = input_tensors["kv_quant_scale_cache"].npu()
        k_scale_cache_data = k_scale
        k_scale_cache_data_out = k_scale
        dequant_scale_w_uq_qr_data = _load("w_qb_scale", dequant_scale_w_uq_qr_shape)
    else:
        # Placeholders: separate zero tensors for input and output, empty scale.
        k_scale_cache_data = torch.zeros(k_scale_cache_out_shape, dtype=torch.float32).npu()
        k_scale_cache_data_out = torch.zeros(k_scale_cache_out_shape, dtype=torch.float32).npu()
        dequant_scale_w_uq_qr_data = torch.Tensor().npu()

    input_data = [token_x_data, w_dq_data, w_uq_qr_data, dequant_scale_w_uq_qr_data,
                  w_uk_data, w_dkv_kr_data, rmsnorm_gamma_cq_data, rmsnorm_gamma_ckv_data,
                  rope_cos_data, rope_sin_data, cache_index_data,
                  kv_cache_data, kr_cache_data, k_scale_cache_data]
    output_data = [output_q_norm_data, output_q_norm_scale_data, output_q_nope_data,
                   output_q_rope_data, output_kv_cache_data, output_kr_cache_data, k_scale_cache_data_out]

    rope_tile_shape = RopeTileShapeConfig(two_dim=[32, 64], three_dim=[32, 32, 128], four_dim=[16, 128, 128, 128])
    kernel = mla_prolog_quant_p if is_p else mla_prolog_quant_d
    kernel(*input_data, *output_data, 1e-5, 1e-5, cache_mode, tile_config, rope_tile_shape)
    torch_npu.npu.synchronize()

    # Comparisons; the quantised outputs use different tolerances.
    print("qNope =======")
    compare(output_q_nope_data.cpu(), golden1.cpu(), "qNope", 0.005, 0.0078125, 0.005)
    print("qRope =======")
    compare(output_q_rope_data.cpu(), golden2.cpu(), "qRope", 0.005, 0.0078125, 0.005)
    if is_quant_b:
        print("qNorm =======")
        compare(output_q_norm_data.cpu(), golden6.cpu(), "qNorm", 1.0, 0.0, 0.005)
        print("qNormScale =======")
        compare(output_q_norm_scale_data.cpu(), golden7.cpu(), "qNormScale", 0.000025, 0.005, 0.005)
    else:
        print("qNorm =======")
        compare(output_q_norm_data.cpu(), golden6.cpu(), "qNorm", 0.0001, 0.0078125, 0.005)
    print("kv =======")
    if is_quant_b:
        compare(output_kv_cache_data.cpu(), golden3.cpu(), "kv", 1.0, 0.0, 0)
    else:
        compare(output_kv_cache_data.cpu(), golden3.cpu(), "kv", 0.0001, 0.0078125, 0)
    print("kr =======")
    compare(output_kr_cache_data.cpu(), golden4.cpu(), "kr", 0.0001, 0.0078125, 0)
    if is_quant_b:
        print("kScaleCache =======")
        compare(k_scale.cpu(), golden5.cpu(), "kScaleCache", 0.000025, 0.005, 0)
@pytest.mark.skip(reason="large shape")
def test_b128_s4k4_pa_nd_bf16_quantb_p():
    '''
    mla_prolog prefill test (bf16 activations, quant-b, paged cache).
    '''
    torch.manual_seed(5)
    prep_env()
    params = {
        "b": 128, "t": 128, "s": 1, "s1": 1, "s2": 1024,
        "n1": 128, "h": 7168, "q_lora_rank": 1536,
        "qk_nope_head_dim": 128, "qk_rope_head_dim": 64,
        "kv_lora_rank": 512, "block_size": 128,
    }
    dtype, w_dtype = pypto.DT_BF16, pypto.DT_INT8
    is_quant_a, is_quant_b, is_nz = False, True, False
    cache_mode = "PA_BSND"

    tile_config = MlaTileConfig()
    tile_config.tile_bs = 128
    align = 16
    # Round min(128, tile_bs) up to a multiple of `align`.
    m_tile = (min(128, tile_config.tile_bs) + align - 1) // align * align
    mv_tile = min(8, tile_config.tile_bs)
    tile_config.m_tile = m_tile
    # Only the first two entries of the default pre-quant tile are overridden.
    tile_config.pre_quant_cube_tile[0] = m_tile
    tile_config.pre_quant_cube_tile[1] = m_tile
    tile_config.cube_qb_tile = [m_tile, m_tile, 256, 256, 256, 256]
    tile_config.cube_wuk_tile = [m_tile, m_tile, 128, 128, 128, 128]
    tile_config.mv_tile = mv_tile
    tile_config.q_vec_tile0 = 32
    tile_config.q_vec_tile1 = 128
    tile_config.k_vec_tile0 = 32
    tile_config.k_vec_tile1 = 512
    tile_config.unroll_list = [128, 64, 32, 16, 8, 4, 2, 1]

    # Every batch entry uses the full s2 sequence length.
    actual_seq = torch.tensor([params["s2"]] * params["b"], dtype=torch.int32).unsqueeze(-1)
    input_tensors, golden_data = gen_mla_prolog_quant_v32_data(
        params, (torch.bfloat16, torch.bfloat16), actual_seq,
        (is_quant_a, is_quant_b), False, 128, "PA_BSND")
    mla_prolog_quant_v32(params, input_tensors, golden_data, dtype, w_dtype,
                         is_quant_a, is_quant_b, is_nz, tile_config, cache_mode, is_p=True)
@pytest.mark.soc("950", "910")
def test_b4_s64k2_pa_nd_bf16_quantb_d():
    '''
    mla_prolog decode test (bf16 activations, quant-b, paged cache).
    '''
    torch.manual_seed(5)
    prep_env()
    params = {
        "b": 4, "t": 8, "s": 2, "s1": 2, "s2": 1024,
        "n1": 128, "h": 7168, "q_lora_rank": 1536,
        "qk_nope_head_dim": 128, "qk_rope_head_dim": 64,
        "kv_lora_rank": 512, "block_size": 128,
    }
    dtype, w_dtype = pypto.DT_BF16, pypto.DT_INT8
    is_quant_a, is_quant_b, is_nz = False, True, False
    cache_mode = "PA_BSND"

    tile_config = MlaTileConfig()
    tile_config.tile_bs = 8
    align = 16
    # Round min(32, tile_bs) up to a multiple of `align`.
    m_tile = (min(32, tile_config.tile_bs) + align - 1) // align * align
    mv_tile = min(8, tile_config.tile_bs)
    tile_config.m_tile = m_tile
    tile_config.pre_quant_cube_tile = [m_tile, m_tile, 256, 256, 128, 128]
    tile_config.cube_qb_tile = [m_tile, m_tile, 256, 256, 256, 256]
    tile_config.cube_wuk_tile = [tile_config.m_tile, tile_config.m_tile, 128, 128, 128, 128]
    tile_config.mv_tile = mv_tile
    tile_config.q_vec_tile0 = 1
    tile_config.q_vec_tile1 = 32
    tile_config.k_vec_tile0 = 2
    tile_config.k_vec_tile1 = 512
    tile_config.unroll_list = [8, 4, 2, 1]

    # Every batch entry uses the full s2 sequence length.
    actual_seq = torch.tensor([params["s2"]] * params["b"], dtype=torch.int32).unsqueeze(-1)
    input_tensors, golden_data = gen_mla_prolog_quant_v32_data(
        params, (torch.bfloat16, torch.bfloat16), actual_seq,
        (is_quant_a, is_quant_b), False, 128, "PA_BSND")
    mla_prolog_quant_v32(params, input_tensors, golden_data, dtype, w_dtype,
                         is_quant_a, is_quant_b, is_nz, tile_config, cache_mode, is_p=False)
@pytest.mark.skip(reason="large shape")
def test_b64_s64k2_pa_nd_bf16_quantb_d():
    '''
    mla_prolog decode int8-quantised high-throughput test case.
    '''
    torch.manual_seed(5)
    prep_env()
    params = {
        "b": 64, "t": 128, "s": 2, "s1": 2, "s2": 1024,
        "n1": 128, "h": 7168, "q_lora_rank": 1536,
        "qk_nope_head_dim": 128, "qk_rope_head_dim": 64,
        "kv_lora_rank": 512, "block_size": 128,
    }
    dtype, w_dtype = pypto.DT_BF16, pypto.DT_INT8
    is_quant_a, is_quant_b, is_nz = False, True, False
    cache_mode = "PA_BSND"

    tile_config = MlaTileConfig()
    tile_config.tile_bs = 32
    tile_config.m_tile = 128
    tile_config.cube_qb_tile = [128, 128, 128, 256, 256, 256]
    tile_config.cube_wuk_tile = [tile_config.m_tile, tile_config.m_tile, 128, 256, 256, 256]
    tile_config.mv_tile = 8
    tile_config.q_vec_tile0 = 32
    tile_config.q_vec_tile1 = 128
    tile_config.k_vec_tile0 = 32
    tile_config.k_vec_tile1 = 512
    # The pre-quant tile and unroll list depend on the NPU architecture.
    on_dav_3510 = pypto.platform.npuarch == 'DAV_3510'
    if on_dav_3510:
        tile_config.pre_quant_cube_tile = [128, 128, 256, 256, 128, 128]
        tile_config.unroll_list = [128, 64, 32, 16, 8, 4, 2, 1]
    else:
        tile_config.pre_quant_cube_tile = [32, 32, 256, 256, 128, 128]
        tile_config.unroll_list = [64, 32, 16, 8, 4, 2, 1]

    # Every batch entry uses the full s2 sequence length.
    actual_seq = torch.tensor([params["s2"]] * params["b"], dtype=torch.int32).unsqueeze(-1)
    input_tensors, golden_data = gen_mla_prolog_quant_v32_data(
        params, (torch.bfloat16, torch.bfloat16), actual_seq,
        (is_quant_a, is_quant_b), False, 128, "PA_BSND")
    mla_prolog_quant_v32(params, input_tensors, golden_data, dtype, w_dtype,
                         is_quant_a, is_quant_b, is_nz, tile_config, cache_mode, is_p=False)
@pytest.mark.soc("950")
def test_b4_s64k2_pa_nd_bf16_d():
    '''
    mla_prolog decode test without quantisation.
    '''
    torch.manual_seed(5)
    prep_env()
    params = {
        "b": 4, "t": 8, "s": 2, "s1": 2, "s2": 64 * 1024,
        "n1": 128, "h": 7168, "q_lora_rank": 1536,
        "qk_nope_head_dim": 128, "qk_rope_head_dim": 64,
        "kv_lora_rank": 512, "block_size": 128,
    }
    dtype, w_dtype = pypto.DT_BF16, pypto.DT_INT8
    is_quant_a, is_quant_b, is_nz = False, False, False
    cache_mode = "PA_BSND"

    tile_config = MlaTileConfig()
    tile_config.tile_bs = 8
    align = 16
    # Round min(32, tile_bs) up to a multiple of `align`.
    m_tile = (min(32, tile_config.tile_bs) + align - 1) // align * align
    mv_tile = min(8, tile_config.tile_bs)
    tile_config.m_tile = m_tile
    tile_config.pre_quant_cube_tile = [m_tile, m_tile, 64, 256, 128, 128]
    tile_config.cube_qb_tile = [m_tile, m_tile, 64, 256, 256, 256]
    tile_config.cube_wuk_tile = [tile_config.m_tile, tile_config.m_tile, 128, 128, 128, 128]
    tile_config.mv_tile = mv_tile
    tile_config.q_vec_tile0 = 1
    tile_config.q_vec_tile1 = 32
    tile_config.k_vec_tile0 = 2
    tile_config.k_vec_tile1 = 512
    tile_config.unroll_list = [8, 4, 2, 1]

    # Every batch entry uses the full s2 sequence length.
    actual_seq = torch.tensor([params["s2"]] * params["b"], dtype=torch.int32).unsqueeze(-1)
    input_tensors, golden_data = gen_mla_prolog_quant_v32_data(
        params, (torch.bfloat16, torch.bfloat16), actual_seq,
        (is_quant_a, is_quant_b), False, 128, "PA_BSND")
    mla_prolog_quant_v32(params, input_tensors, golden_data, dtype, w_dtype,
                         is_quant_a, is_quant_b, is_nz, tile_config, cache_mode, is_p=False)
@pytest.mark.soc("950")
def test_b64_s64k2_pa_nd_bf16_d():
    '''
    mla_prolog decode test without quantisation (batch 64).
    '''
    torch.manual_seed(5)
    prep_env()
    params = {
        "b": 64, "t": 128, "s": 2, "s1": 2, "s2": 64 * 1024,
        "n1": 128, "h": 7168, "q_lora_rank": 1536,
        "qk_nope_head_dim": 128, "qk_rope_head_dim": 64,
        "kv_lora_rank": 512, "block_size": 128,
    }
    dtype, w_dtype = pypto.DT_BF16, pypto.DT_INT8
    is_quant_a, is_quant_b, is_nz = False, False, False
    cache_mode = "PA_BSND"

    tile_config = MlaTileConfig()
    tile_config.tile_bs = 128
    align = 16
    # Round min(128, tile_bs) up to a multiple of `align`.
    m_tile = (min(128, tile_config.tile_bs) + align - 1) // align * align
    mv_tile = min(8, tile_config.tile_bs)
    tile_config.m_tile = m_tile
    # Cube tiles and the unroll list depend on the NPU architecture.
    on_dav_3510 = pypto.platform.npuarch == 'DAV_3510'
    if on_dav_3510:
        tile_config.pre_quant_cube_tile = [m_tile, m_tile, 64, 256, 128, 128]
        tile_config.cube_qb_tile = [128, 128, 64, 256, 256, 256]
        tile_config.cube_wuk_tile = [tile_config.m_tile, tile_config.m_tile, 128, 128, 128, 128]
    else:
        tile_config.pre_quant_cube_tile = [32, 32, 64, 256, 128, 128]
        tile_config.cube_qb_tile = [128, 128, 64, 256, 256, 256]
        tile_config.cube_wuk_tile = [tile_config.m_tile, tile_config.m_tile, 128, 256, 256, 256]
    tile_config.mv_tile = mv_tile
    tile_config.q_vec_tile0 = 32
    tile_config.q_vec_tile1 = 128
    tile_config.k_vec_tile0 = 32
    tile_config.k_vec_tile1 = 512
    if pypto.platform.npuarch == 'DAV_3510':
        tile_config.unroll_list = [128, 64, 32, 16, 8, 4, 2, 1]
    else:
        tile_config.unroll_list = [64, 32, 16, 8, 4, 2, 1]

    # Every batch entry uses the full s2 sequence length.
    actual_seq = torch.tensor([params["s2"]] * params["b"], dtype=torch.int32).unsqueeze(-1)
    input_tensors, golden_data = gen_mla_prolog_quant_v32_data(
        params, (torch.bfloat16, torch.bfloat16), actual_seq,
        (is_quant_a, is_quant_b), False, 128, "PA_BSND")
    mla_prolog_quant_v32(params, input_tensors, golden_data, dtype, w_dtype,
                         is_quant_a, is_quant_b, is_nz, tile_config, cache_mode, is_p=False)
if __name__ == "__main__":
    # Running the file directly executes only the small decode quant-b case.
    log_format = '%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)
    test_b4_s64k2_pa_nd_bf16_quantb_d()