"""
"""
'''
'''
import os
from pathlib import Path
import math
import pytest
import torch
import torch_npu
import pypto
import logging
import numpy as np
from mla_prolog_quant_v4_impl import mla_prolog_v4, MlaPrologV4Attrs, \
MlaTileConfigs, MlaPrologV4Configs, check_input_output_shape_dtype, mla_prolog_quant_pypto
from utils.compare import compare
torch.manual_seed(5)
def prep_env():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
torch_npu.npu.config.allow_internal_format = True
def rms_norm_new(x, eps=1e-6):
x_dtype = x.dtype
mean_coff = 1.0 / x.shape[-1]
x_f32 = x.to(torch.float32)
square = x_f32 * x_f32
mean_res = square * mean_coff
reduce_sum = torch.sum(mean_res, dim=-1, keepdims=True) + eps
reduce_sqrt = torch.sqrt(reduce_sum)
res = x_f32 / reduce_sqrt
if x_dtype != torch.float32:
res = res.to(x_dtype)
return res
def rms_norm(x, gamma, eps=1e-6):
x_dtype = x.dtype
mean_coff = 1.0 / x.shape[-1]
gamma = gamma.to(torch.float32)
x_f32 = x.to(torch.float32)
square = x_f32 * x_f32
mean_res = square * mean_coff
reduce_sum = torch.sum(mean_res, dim=-1, keepdims=True) + eps
reduce_sqrt = torch.sqrt(reduce_sum)
res_div = x_f32 / reduce_sqrt
res = res_div * gamma
if x_dtype != torch.float32:
res = res.to(x_dtype)
return res
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.concatenate((-x2, x1), dim=-1)
def apply_rotary_pos_emb_v2(q, k, cos, sin, unsqueeze_dim=1):
input_dtype = q.dtype
q_clone = q.clone()
k_clone = k.clone()
t, nq, d = q.shape
t, nk, d = k.shape
q = q.reshape(t, nq, d//2, 2).permute(0, 1, 3, 2).reshape(t, nq, d)
k = k.reshape(t, nk, d//2, 2).permute(0, 1, 3, 2).reshape(t, nk, d)
q_t = rotate_half(q)
k_t = rotate_half(k)
q_new = q_t.reshape(t, nq, 2, d//2).permute(0, 1, 3, 2).reshape(t, nq, d)
k_new = k_t.reshape(t, nk, 2, d//2).permute(0, 1, 3, 2).reshape(t, nk, d)
q_new = q_new.to(torch.float32)
k_new = k_new.to(torch.float32)
cos = torch.unsqueeze(cos, dim=unsqueeze_dim)
sin = torch.unsqueeze(sin, dim=unsqueeze_dim)
cos = cos.to(torch.float32)
sin = sin.to(torch.float32)
q_embed = q_clone * cos + q_new * sin
k_embed = k_clone * cos + k_new * sin
if input_dtype != torch.float32:
q_embed, k_embed = q_embed.to(input_dtype), k_embed.to(input_dtype)
return q_embed, k_embed
def tensor_to_file(t: torch.Tensor, output: Path):
with open(str(output), "wb") as f:
dtype = t.dtype
if dtype == torch.bfloat16:
dtype = torch.int16
for each in t:
f.write(each.view(dtype).cpu().numpy().tobytes())
def quant(input_t, is_pertoken: bool = True, has_smooth=False, smooth_cq=None):
input_fp32 = input_t.to(torch.float32)
if has_smooth:
input_fp32 = input_fp32 * smooth_cq
abs_res = torch.abs(input_fp32)
reduce_idx = -1
if not is_pertoken:
reduce_idx = -2
max_value = torch.max(abs_res, dim=reduce_idx, keepdims=True)[0]
scale_quant = 127 / max_value
out_fp32 = input_fp32 * scale_quant
out_int32 = torch.round(out_fp32).to(torch.int32)
out_fp16 = out_int32.to(torch.float16)
out_int8 = torch.trunc(out_fp16).to(torch.int8)
scale_dequant = 1 / scale_quant
return out_int8, scale_dequant
def mla_prolog_compute(inputs):
dtype = inputs.get("dtype")
x = inputs.get("x")
wq_a = inputs.get("wq_a")
wq_b = inputs.get("wq_b")
qk_rope_head_dim = inputs.get("qk_rope_head_dim")
gamma_cq = inputs.get("gamma_cq")
gamma_ckv = inputs.get("gamma_ckv")
w_kv = inputs.get("w_kv")
cos = inputs.get("cos")
sin = inputs.get("sin")
wq_b_scale = inputs.get("wq_b_scale")
t, h = x.shape
qk_rope_head_dim = cos.shape[1]
head_dim = w_kv.shape[-1]
num_heads = wq_b.shape[-1] // head_dim
''' q '''
q_a_proj = torch.matmul(x.to(torch.float32), wq_a.to(torch.float32))
q_a_layernorm = rms_norm(q_a_proj, gamma_cq)
q_a_layernorm_out = q_a_layernorm.to(torch.bfloat16)
q_a_quant, q_a_quant_scale = quant(q_a_layernorm, True)
q_b_proj = torch.matmul(q_a_quant.to(torch.int32), wq_b.to(torch.int32))
''' dequant'''
q_b_proj_fp32 = q_b_proj.to(torch.float32)
q_b_proj_dequant = q_b_proj_fp32 * q_a_quant_scale
q_b_proj_dequant = q_b_proj_dequant * wq_b_scale
q_reshape = q_b_proj_dequant.reshape(t, num_heads, head_dim)
q_reshape = rms_norm_new(q_reshape)
q_reshape = q_reshape.to(torch.bfloat16)
""" kv """
kv_a_proj = torch.matmul(x.to(torch.float32), w_kv.to(torch.float32))
kv_a_proj_norm = rms_norm(kv_a_proj, gamma_ckv)
kv_reshape = kv_a_proj_norm.reshape(t, head_dim)
kv_reshape = kv_reshape.to(torch.bfloat16)
""" rope"""
q_pe = q_reshape[:, :, -qk_rope_head_dim:]
k_pe = kv_reshape[:, -qk_rope_head_dim:]
k_pe_r = k_pe.reshape(t, 1, qk_rope_head_dim)
q_embed, k_embed = apply_rotary_pos_emb_v2(q_pe, k_pe_r, cos, sin, 1)
k_embed_r = k_embed.reshape(t, qk_rope_head_dim)
q_out = torch.concat([q_reshape[:, :, :-qk_rope_head_dim], q_embed], -1)
kv_out = torch.concat([kv_reshape[:, :-qk_rope_head_dim], k_embed_r], -1)
return q_out, kv_out, q_a_quant, q_a_quant_scale
def gen_mla_prolog_input_data(params, dtypes, is_quant=True, is_nz=False):
dtype = dtypes
t = params.get("t")
h = params.get("h")
n = params.get("num_heads")
q_lora_rank = params.get("q_lora_rank")
head_dim = params.get("head_dim")
rope_head_dim = params.get("qk_rope_head_dim")
x_shape = [t, h]
wq_a_shape = [h, q_lora_rank]
gamma_cq_shape = [q_lora_rank]
gamma_ckv_shape = [head_dim]
wq_b_shape = [q_lora_rank, n * head_dim]
wkv_shape = [h, head_dim]
cos_shape = [t, rope_head_dim]
res = [None] * 9
x = torch.empty(x_shape, dtype=dtype).uniform_(-1, 1)
res[0] = x
wq_a = torch.empty(wq_a_shape, dtype=dtype).uniform_(-0.1, 0.1)
wq_b = torch.empty(wq_b_shape, dtype=dtype).uniform_(-0.1, 0.1)
w_kv = torch.empty(wkv_shape, dtype=dtype).uniform_(-0.1, 0.1)
if is_quant:
wq_b, wq_b_scale = quant(wq_b, False)
res[-1] = wq_b_scale
res[1] = wq_a
res[2] = wq_b
res[3] = w_kv
gamma_cq = torch.empty(gamma_cq_shape, dtype=dtype).uniform_(-1, 1)
gamma_ckv = torch.empty(gamma_ckv_shape, dtype=dtype).uniform_(-1, 1)
res[4] = gamma_cq
res[5] = gamma_ckv
cos = torch.empty(cos_shape, dtype=dtype).uniform_(-1, 1)
sin = torch.empty(cos_shape, dtype=dtype).uniform_(-1, 1)
res[6] = cos
res[7] = sin
return res
def gen_mla_prolog_data(params, dtype, is_quant=True, is_nz=False):
x, wq_a, wq_b, w_kv, gamma_cq, gamma_ckv, cos, sin, wq_b_scale = \
gen_mla_prolog_input_data(params, dtype, is_quant, is_nz)
inputs = dict()
inputs["x"] = x
inputs["wq_a"] = wq_a
inputs["wq_b"] = wq_b
inputs["w_kv"] = w_kv
inputs["cos"] = cos
inputs["sin"] = sin
inputs["gamma_cq"] = gamma_cq
inputs["gamma_ckv"] = gamma_ckv
inputs["wq_b_scale"] = wq_b_scale
q_out, kv_out, qr_out, qr_scale_out = mla_prolog_compute(inputs)
outputs = {"q_golden": q_out, "kv_golden": kv_out, "qr_golden": qr_out, "qr_scale_golden": qr_scale_out}
return inputs, outputs
def convert_pypto_to_torch_type(pypto_type):
if pypto_type == pypto.DataType.DT_INT8:
return torch.int8
elif pypto_type == pypto.DataType.DT_INT32:
return torch.int32
elif pypto_type == pypto.DataType.DT_FP32:
return torch.float32
elif pypto_type == pypto.DataType.DT_FP16:
return torch.float16
elif pypto_type == pypto.DataType.DT_BF16:
return torch.bfloat16
else:
raise ValueError(f"Unsupported pypto.DataType: {pypto_type}")
class MLA_MODEL(torch.nn.Module):
def forward(self, token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale):
return mla_prolog_quant_pypto(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale)
def mla_prolog(params, input_tensors, golden_tensors, dtype, is_nz):
d_type = pypto.DataType.DT_FP16 if dtype == pypto.DataType.DT_FP16 else pypto.DataType.DT_BF16
t = params['t']
n1 = params["num_heads"]
h = params["h"]
q_lora_rank = params["q_lora_rank"]
qk_rope_head_dim = params["qk_rope_head_dim"]
head_dim = params["head_dim"]
token_x_shape = [t, h]
wq_a_shape = [h, q_lora_rank]
wq_b_shape = [q_lora_rank, n1 * head_dim]
wkv_shape = [h, head_dim]
rope_cos_shape = [t, qk_rope_head_dim]
rmsnorm_gamma_cq_shape = [q_lora_rank]
rmsnorm_gamma_ckv_shape = [head_dim]
wq_b_scale_shape = [n1 * head_dim, 1]
q_out_shape = [t, n1, head_dim]
kv_out_shape = [t, head_dim]
qr_out_shape = [t, q_lora_rank]
qr_scale_shape = [t, 1]
if is_nz:
wq_a_nz = torch_npu.npu_format_cast(input_tensors["wq_a"].reshape(wq_a_shape).npu().contiguous(), \
torch_npu.Format.FRACTAL_NZ)
wq_b_nz = torch_npu.npu_format_cast(input_tensors["wq_b"].reshape(wq_b_shape).npu().contiguous(), \
torch_npu.Format.FRACTAL_NZ)
wkv_nz = torch_npu.npu_format_cast(input_tensors["w_kv"].reshape(wkv_shape).npu().contiguous(), \
torch_npu.Format.FRACTAL_NZ)
input_tensors["wq_a"] = wq_a_nz
input_tensors["wq_b"] = wq_b_nz
input_tensors["w_kv"] = wkv_nz
token_x = input_tensors["x"].reshape(token_x_shape).npu()
wq_a = input_tensors["wq_a"].reshape(wq_a_shape).npu()
wq_b = input_tensors["wq_b"].reshape(wq_b_shape).npu()
wkv = input_tensors["w_kv"].reshape(wkv_shape).npu()
rope_cos = input_tensors["cos"].reshape(rope_cos_shape).npu()
rope_sin = input_tensors["sin"].reshape(rope_cos_shape).npu()
gamma_cq = input_tensors["gamma_cq"].reshape(rmsnorm_gamma_cq_shape).npu()
gamma_ckv = input_tensors["gamma_ckv"].reshape(rmsnorm_gamma_ckv_shape).npu()
wq_b_scale = input_tensors["wq_b_scale"].reshape(wq_b_scale_shape).npu()
inputs = [token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale]
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
compiler_config = CompilerConfig()
compiler_config.mode = "reduce-overhead"
npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
model = torch.compile(MLA_MODEL(), dynamic=False, fullgraph=True, backend=npu_backend)
output_q_data, output_kv_data, output_qr_data, output_qr_scale_data = model(*inputs)
pypto.runtime._device_synchronize()
golden1 = golden_tensors["q_golden"].reshape(q_out_shape)
golden2 = golden_tensors["kv_golden"].reshape(kv_out_shape)
golden3 = golden_tensors["qr_golden"].reshape(qr_out_shape)
golden4 = golden_tensors["qr_scale_golden"].reshape(qr_scale_shape)
print("q ================")
compare(output_q_data.cpu(), golden1.cpu(), "qOut", 0.0001, 0.0078125, 0.005)
print("kv ================")
compare(output_kv_data.cpu(), golden2.cpu(), "kvOut", 0.0001, 0.0078125, 0.005)
print("qr ================")
compare(output_qr_data.cpu(), golden3.cpu(), "qrOut", 1, 0, 0)
print(" qr_scale ==========")
compare(output_qr_scale_data.cpu(), golden4.cpu(), "qrScaleOut", 0.000025, 0.005, 0.005)
print("=========== pass ==========")
def mla_prolog_eager(params, input_tensors, golden_tensors, dtype, is_nz, attrs, configs):
d_type = pypto.DataType.DT_FP16 if dtype == pypto.DataType.DT_FP16 else pypto.DataType.DT_BF16
t = params['t']
n1 = params["num_heads"]
h = params["h"]
q_lora_rank = params["q_lora_rank"]
qk_rope_head_dim = params["qk_rope_head_dim"]
head_dim = params["head_dim"]
token_x_shape = [t, h]
wq_a_shape = [h, q_lora_rank]
wq_b_shape = [q_lora_rank, n1 * head_dim]
wkv_shape = [h, head_dim]
rope_cos_shape = [t, qk_rope_head_dim]
rmsnorm_gamma_cq_shape = [q_lora_rank]
rmsnorm_gamma_ckv_shape = [head_dim]
wq_b_scale_shape = [n1 * head_dim, 1]
q_out_shape = [t, n1, head_dim]
kv_out_shape = [t, head_dim]
qr_out_shape = [t, q_lora_rank]
qr_scale_shape = [t, 1]
if is_nz:
wq_a_nz = torch_npu.npu_format_cast(input_tensors["wq_a"].reshape(wq_a_shape).npu().contiguous(), \
torch_npu.Format.FRACTAL_NZ)
wq_b_nz = torch_npu.npu_format_cast(input_tensors["wq_b"].reshape(wq_b_shape).npu().contiguous(), \
torch_npu.Format.FRACTAL_NZ)
wkv_nz = torch_npu.npu_format_cast(input_tensors["w_kv"].reshape(wkv_shape).npu().contiguous(), \
torch_npu.Format.FRACTAL_NZ)
input_tensors["wq_a"] = wq_a_nz
input_tensors["wq_b"] = wq_b_nz
input_tensors["w_kv"] = wkv_nz
token_x = input_tensors["x"].reshape(token_x_shape).npu()
wq_a = input_tensors["wq_a"].reshape(wq_a_shape).npu()
wq_b = input_tensors["wq_b"].reshape(wq_b_shape).npu()
wkv = input_tensors["w_kv"].reshape(wkv_shape).npu()
rope_cos = input_tensors["cos"].reshape(rope_cos_shape).npu()
rope_sin = input_tensors["sin"].reshape(rope_cos_shape).npu()
gamma_cq = input_tensors["gamma_cq"].reshape(rmsnorm_gamma_cq_shape).npu()
gamma_ckv = input_tensors["gamma_ckv"].reshape(rmsnorm_gamma_ckv_shape).npu()
wq_b_scale = input_tensors["wq_b_scale"].reshape(wq_b_scale_shape).npu()
output_q_data = torch.empty([token_x.size(0), wq_b.size(1) // gamma_ckv.size(0), gamma_ckv.size(0)], dtype=token_x.dtype, device=f'{token_x.device}')
output_kv_data = torch.empty([token_x.size(0), gamma_ckv.size(0)], dtype=token_x.dtype, device=f'{token_x.device}')
output_qr_data = torch.empty([token_x.size(0), gamma_cq.size(0)], dtype=torch.int8, device=f'{token_x.device}')
output_qr_scale_data = torch.empty([token_x.size(0), 1], dtype=torch.float32, device=f'{token_x.device}')
check_input_output_shape_dtype(token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale,
output_q_data, output_kv_data, output_qr_data, output_qr_scale_data)
tile_configs = MlaTileConfigs(
two_dim_tile=[1, 64],
three_dim_tile=[1, 64, 64],
four_dim_tile=[1, 64, 64, 64],
vec_tile=[max(1, token_x.shape[0]//16), 64]
)
params_info = [token_x, wq_a, wq_b, wkv, gamma_cq, gamma_ckv, rope_cos, rope_sin, wq_b_scale, output_q_data, output_kv_data, output_qr_data, output_qr_scale_data]
mla_prolog_v4(*params_info, attrs, configs, tile_configs)
pypto.runtime._device_synchronize()
golden1 = golden_tensors["q_golden"].reshape(q_out_shape)
golden2 = golden_tensors["kv_golden"].reshape(kv_out_shape)
golden3 = golden_tensors["qr_golden"].reshape(qr_out_shape)
golden4 = golden_tensors["qr_scale_golden"].reshape(qr_scale_shape)
print("q ================")
compare(output_q_data.cpu(), golden1.cpu(), "qOut", 0.0001, 0.0078125, 0.005)
print("kv ================")
compare(output_kv_data.cpu(), golden2.cpu(), "kvOut", 0.0001, 0.0078125, 0.005)
print("qr ================")
compare(output_qr_data.cpu(), golden3.cpu(), "qrOut", 1, 0, 0)
print(" qr_scale ==========")
compare(output_qr_scale_data.cpu(), golden4.cpu(), "qrScaleOut", 0.000025, 0.005, 0.005)
print("=========== pass ==========")
@pytest.mark.skip("t=4")
def test_t4_pa_nd_bf16():
prep_env()
params = {
't': 4,
'num_heads': 64,
'h': 4096,
'q_lora_rank': 1024,
'head_dim': 512,
'qk_rope_head_dim': 64,
}
dtype = pypto.DataType.DT_BF16
is_quant = True
is_nz = True
input_tensors, golden_data = gen_mla_prolog_data(params, torch.bfloat16, is_quant, is_nz)
mla_prolog(params, input_tensors, golden_data, dtype, is_nz)
@pytest.mark.skip("t=16")
def test_t4_pa_nd_bf16_eager():
prep_env()
params = {
't': 4,
'num_heads': 64,
'h': 4096,
'q_lora_rank': 1024,
'head_dim': 512,
'qk_rope_head_dim': 64,
}
dtype = pypto.DataType.DT_BF16
is_quant = True
is_nz = True
attrs = MlaPrologV4Attrs(eps=1e-6)
configs = MlaPrologV4Configs(unroll_list=[128, 64, 32, 16, 1],
cube_l1_reuse_setting={2: 4},
mg_copyin_upper_bound=2 * 1024 * 1024,
pg_upper_bound=8192,
block_size=128,
t_sub_tile=1,
chunk_size=2)
input_tensors, golden_data = gen_mla_prolog_data(params, torch.bfloat16, is_quant, is_nz)
mla_prolog_eager(params, input_tensors, golden_data, dtype, is_nz, attrs, configs)
@pytest.mark.skip("t=16")
def test_t16_pa_nd_bf16():
prep_env()
params = {
't': 16,
'num_heads': 64,
'h': 4096,
'q_lora_rank': 1024,
'head_dim': 512,
'qk_rope_head_dim': 64,
}
dtype = pypto.DataType.DT_BF16
is_nz = True
is_quant = True
input_tensors, golden_data = gen_mla_prolog_data(params, torch.bfloat16, is_quant, is_nz)
mla_prolog(params, input_tensors, golden_data, dtype, is_nz)
@pytest.mark.skip("t=16")
def test_t16_pa_nd_bf16_eager():
prep_env()
params = {
't': 16,
'num_heads': 64,
'h': 4096,
'q_lora_rank': 1024,
'head_dim': 512,
'qk_rope_head_dim': 64,
}
dtype = pypto.DataType.DT_BF16
is_quant = True
is_nz = True
attrs = MlaPrologV4Attrs(eps=1e-6)
configs = MlaPrologV4Configs(unroll_list=[128, 64, 32, 16, 1],
cube_l1_reuse_setting={2: 4},
mg_copyin_upper_bound=2 * 1024 * 1024,
pg_upper_bound=8192,
block_size=128,
t_sub_tile=1,
chunk_size=2)
input_tensors, golden_data = gen_mla_prolog_data(params, torch.bfloat16, is_quant, is_nz)
mla_prolog_eager(params, input_tensors, golden_data, dtype, is_nz, attrs, configs)
@pytest.mark.skip("t=512")
def test_t512_pa_nd_bf16():
prep_env()
params = {
't': 512,
'num_heads': 64,
'h': 4096,
'q_lora_rank': 1024,
'head_dim': 512,
'qk_rope_head_dim': 64,
}
dtype = pypto.DataType.DT_BF16
is_nz = True
is_quant = True
input_tensors, golden_data = gen_mla_prolog_data(params, torch.bfloat16, is_quant, is_nz)
mla_prolog(params, input_tensors, golden_data, dtype, is_nz)
if __name__ == "__main__":
logging.basicConfig(
format='%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s',
level=logging.INFO
)
test_t4_pa_nd_bf16()