"""
本脚本有 2 种执行模式:
1. CI批跑时, 由 cmake/scripts/golden_ctrl.py 调用, 为避免日志过多, 此时 logging 级别为 logging.INFO;
2. 单独调试时, 本脚本单独被调用, 此时 logging 级别为 logging.DEBUG;
"""
import math
import sys
import logging
from pathlib import Path
from typing import List
import numpy as np
import torch
from ml_dtypes import bfloat16
if __name__ == "__main__":
""" 单独调试时配置 """
logging.basicConfig(format='%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s',
level=logging.DEBUG)
g_src_root: Path = Path(Path(__file__).parent, "../../../../../").resolve()
logging.debug("SrcRoot: %s", g_src_root)
g_ctrl_path: Path = Path(g_src_root, "cmake/scripts")
if str(g_ctrl_path) not in sys.path:
sys.path.append(str(g_ctrl_path))
from golden_register import GoldenRegister
else:
from golden_register import GoldenRegister
def gen_post_quantize_np(input_fp32):
abs_res = np.abs(input_fp32)
max_value = np.max(abs_res, axis=-1, keepdims=True)
scale_quant = 127.0 / max_value
out_fp32 = input_fp32 * scale_quant
out_int32 = np.rint(out_fp32).astype(np.int32)
out_int8 = np.clip(out_int32, -128, 127).astype(np.int8)
scale_dequant = 1.0 / scale_quant
return out_int8, scale_dequant
def gen_post_quant_mm_np(a, w, scale_w, output: Path):
a_fp32 = a.astype(np.float32)
quantized_a, scale_dequant_a = gen_post_quantize_np(a_fp32)
quant1_int8_path = Path(output, 'quant0_int8.bin')
quant1_fp32_path = Path(output, 'quant0_fp32.bin')
quantized_a.tofile(quant1_int8_path)
scale_dequant_a.tofile(quant1_fp32_path)
a_int32 = quantized_a.astype(np.int32)
w_int32 = w.astype(np.int32)
res_int32 = np.matmul(a_int32, w_int32)
mm5_int32_path = Path(output, 'mm5_int32.bin')
res_int32.tofile(mm5_int32_path)
res = res_int32.astype(np.float32)
mm5_fp32_path = Path(output, 'mm5_fp32.bin')
res.tofile(mm5_fp32_path)
res = res * scale_dequant_a
res = res * scale_w
return res.astype(a.dtype)
def gen_post_data(output_dir: Path, params, dtype, pa_out=None, is_quant=True, is_nz=True):
input_b = params.get("b")
input_s = params.get("s")
input_n = params.get("num_heads")
input_h = params.get("h")
kv_lora_rank = params.get("kv_lora_rank")
v_head_dim = params.get("v_head_dim")
params_path = Path(output_dir, 'params.bin')
input_path = Path(output_dir, 'input.bin')
t1_path = Path(output_dir, 't1.bin')
r1_path = Path(output_dir, 'r1.bin')
cast1_path = Path(output_dir, 'cast1.bin')
w_uv_path = Path(output_dir, 'w_uv.bin')
w_uv_scale_w_path = Path(output_dir, 'w_uv_scale_w.bin')
cast0_out_path = Path(output_dir, 'cast0_out.bin')
abs_out_path = Path(output_dir, 'abs_out.bin')
mul0_out_path = Path(output_dir, 'mul0_out.bin')
rms_out_path = Path(output_dir, 'rms_out.bin')
quant1_int8_path = Path(output_dir, 'quant0_int8.bin')
quant1_fp32_path = Path(output_dir, 'quant0_fp32.bin')
bmm4_int32_path = Path(output_dir, 'bmm4_int32.bin')
bmm4_path = Path(output_dir, 'bmm4.bin')
t3_path = Path(output_dir, 't3.bin')
r2_path = Path(output_dir, 'r2.bin')
w_o_path = Path(output_dir, 'w_o.bin')
w_o_nd_path = Path(output_dir, 'w_o_nd.bin')
w_o_scale_w_path = Path(output_dir, 'w_o_scale_w.bin')
bmm5_path = Path(output_dir, 'bmm5.bin')
attn_output_path = Path(output_dir, 'attn_output.bin')
complete = (params_path.exists() and input_path.exists() and attn_output_path.exists())
complete = False
if complete:
logging.debug("Case(), Golden complete.")
else:
logging.debug("Start generate papost golden ...")
dtype_num = 0
if dtype == np.float32:
dtype_num = 0
elif dtype == np.float16:
dtype_num = 1
elif dtype == bfloat16:
dtype_num = 2
params_num = np.array([input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim, dtype_num], dtype=np.int64)
np.random.seed(0)
input_t = pa_out
if pa_out is None:
input_t = np.random.randn(input_b * input_n * input_s, kv_lora_rank).astype(np.float32)
w_uv = np.random.randn(input_n, kv_lora_rank, v_head_dim).astype(dtype)
w_uv_scale_w = np.random.randn(input_n, 1, v_head_dim).astype(np.float32) * 0.001
w_o = np.random.randint(-128, 128, size=(input_n * v_head_dim, input_h), dtype=np.int8)
w_o_scale_w = np.random.randn(input_h).astype(np.float32) * 0.001
w_o_scale_w.tofile(w_o_scale_w_path)
params_num.tofile(params_path)
input_t.tofile(input_path)
w_uv.tofile(w_uv_path)
w_uv_scale_w.tofile(w_uv_scale_w_path)
w_o.tofile(w_o_nd_path)
w_o_nz = np.reshape(w_o, (input_n * v_head_dim, input_h // 32, 32))
w_o_nz = np.transpose(w_o_nz, (1, 0, 2))
w_o_nz.tofile(w_o_path)
cast1 = input_t.astype(dtype)
cast1.tofile(cast1_path)
r1 = np.reshape(cast1, (input_b * input_s, input_n, kv_lora_rank))
r1.tofile(r1_path)
t1 = np.transpose(r1, (1, 0, 2))
t1.tofile(t1_path)
calc_input = t1
bmm4 = np.matmul(calc_input.astype(np.float32), w_uv.astype(np.float32))
if dtype != np.float32:
bmm4 = bmm4.astype(dtype)
bmm4.tofile(bmm4_path)
t3 = np.transpose(bmm4, (1, 0, 2))
t3.tofile(t3_path)
r2 = np.reshape(t3, (input_b * input_s, input_n * v_head_dim))
r2.tofile(r2_path)
bmm5_i = r2
bmm5 = gen_post_quant_mm_np(bmm5_i, w_o, w_o_scale_w, output_dir)
bmm5.tofile(bmm5_path)
bmm5 = np.reshape(bmm5, (input_b, input_s, input_h))
bmm5.tofile(attn_output_path)
return bmm5
def split_pa_tensor_by_b(input_tensor):
split_data = np.split(input_tensor, input_tensor.shape[0])
return split_data
def gen_pa_softmax(x):
x = x.astype(np.float32)
x_max = x.max(axis=-1, keepdims=True)
x_sub = x - x_max
y = np.exp(x_sub)
x_sum = y.sum(axis=-1, keepdims=True)
ans = y
return ans, x_sum, x_max
def gen_pa_uniform_data(data_shape, min_value, max_value, dtype):
if min_value == 0 and max_value == 0:
return np.zeros(data_shape, dtype=dtype)
if dtype == np.bool_:
return np.random.choice([True, False], size=data_shape)
return np.random.uniform(low=min_value, high=max_value, size=data_shape).astype(
dtype
)
def trans_pa_bnsd_to_bsh(tensor, shape):
if len(shape) == 4:
b = shape[0]
n = shape[1]
s = shape[2]
d = shape[3]
h = n * d
return tensor.transpose(0, 2, 1, 3).reshape(b, s, h)
else:
return tensor
def gen_pa_data(output_dir: Path, params, dtype, q_out, q_rope_out, kv_cache_out, kr_cache_out, block_size=4096,
n_tile=128, is_nz=False):
np.random.seed(0)
b = params.get("b")
n_q = params.get("num_heads")
skv = params.get("s2")
s_q = params.get("s")
n_kv = params.get("n2")
kv_lora_rank = params.get("kv_lora_rank")
qk_rope_dim = params.get("qk_rope_head_dim")
d_q = kv_lora_rank + qk_rope_dim
d_k = kv_lora_rank + qk_rope_dim
d_v = kv_lora_rank
scalar = d_q ** -0.5
if isinstance(skv, int):
actual_seq_len = [skv] * b
elif isinstance(skv, list):
if len(skv) == b:
actual_seq_len = skv
else:
raise RuntimeError("unsupported skv list length")
else:
raise RuntimeError("unsupported skv data type")
s_max = max(actual_seq_len)
shape_q1 = [b, n_q, s_q, kv_lora_rank]
shape_q2 = [b, n_q, s_q, qk_rope_dim]
shape_k1 = [b, n_kv, s_max, kv_lora_rank]
shape_k2 = [b, n_kv, s_max, qk_rope_dim]
shape_k = [b, n_kv, s_max, d_k]
shape_v = [b, n_kv, s_max, d_v]
atten_out_shape = [b, n_q, s_q, d_v]
block_num_per_batch = []
block_num_min = 0
block_num = 0
q_bnsd1 = q_out
q_bnsd2 = q_rope_out
k_bnsd1 = kv_cache_out
k_bnsd2 = kr_cache_out
v_bnsd = k_bnsd1
if q_out is None or q_rope_out is None or kv_cache_out is None or kr_cache_out is None:
q_bnsd1 = gen_pa_uniform_data(shape_q1, -1, 1, dtype)
q_bnsd2 = gen_pa_uniform_data(shape_q2, -1, 1, dtype)
k_bnsd1 = gen_pa_uniform_data(shape_k1, -1, 1, dtype)
k_bnsd2 = gen_pa_uniform_data(shape_k2, -1, 1, dtype)
v_bnsd = k_bnsd1
q_bnsd = np.concatenate((q_bnsd1, q_bnsd2), axis=-1)
k_bnsd = np.concatenate((k_bnsd1, k_bnsd2), axis=-1)
for actual_seq in actual_seq_len:
block_num_per_batch.append(math.ceil(actual_seq / block_size))
block_num_min += math.ceil(actual_seq / block_size)
block_table_shape = [b, math.ceil(s_max / block_size)]
block_num = block_num_min
block_idx_list = np.arange(0, block_num, 1)
block_idx = 0
block_table = [-1] * block_table_shape[1]
block_table = np.tile(block_table, (block_table_shape[0], 1)).astype(np.int32)
block_table_batch_idx = 0
for idx in block_num_per_batch:
for j in range(idx):
block_table[block_table_batch_idx][j] = (block_idx_list[block_idx])
block_idx += 1
block_table_batch_idx += 1
logging.debug("block_table %s", block_table)
k_cache = np.zeros([block_num, block_size, n_kv * d_k]).astype(dtype)
v_cache = np.zeros([block_num, block_size, n_kv * d_v]).astype(dtype)
logging.debug("dtype %s shape %s ", type(k_bnsd), k_bnsd.shape)
k_tensor_bsh_raw = trans_pa_bnsd_to_bsh(k_bnsd, shape_k)
v_tensor_bsh_raw = trans_pa_bnsd_to_bsh(v_bnsd, shape_v)
k_tensor_bsh = np.zeros((b, block_table_shape[1] * block_size, n_kv * d_k)).astype(dtype)
v_tensor_bsh = np.zeros((b, block_table_shape[1] * block_size, n_kv * d_v)).astype(dtype)
k_tensor_bsh[:, :k_tensor_bsh_raw.shape[1], :] = k_tensor_bsh_raw[:, :, :]
v_tensor_bsh[:, :v_tensor_bsh_raw.shape[1], :] = v_tensor_bsh_raw[:, :, :]
for b_idx in range(b):
for block_i, kv_cache_blk_id in enumerate(block_table[b_idx]):
block_offset = block_i * block_size
if kv_cache_blk_id == -1:
continue
else:
k_cache[kv_cache_blk_id, 0:block_size, :] = k_tensor_bsh[
b_idx, block_offset:(block_offset + block_size), :]
v_cache[kv_cache_blk_id, 0:block_size, :] = v_tensor_bsh[
b_idx, block_offset:(block_offset + block_size), :]
attent_out = np.zeros(atten_out_shape, dtype=np.float32)
k_tensor_list = split_pa_tensor_by_b(k_bnsd)
v_tensor_list = split_pa_tensor_by_b(v_bnsd)
for b_index in range(b):
matmul_dtype = np.float32
act_seq = actual_seq_len[b_index]
k_sub_tensor = k_tensor_list[b_index]
v_sub_tensor = v_tensor_list[b_index]
q_tensor_cur = q_bnsd[b_index:(b_index + 1), :, :, :]
k_cur = k_sub_tensor[:, :, :act_seq, :]
v_cur = v_sub_tensor[:, :, :act_seq, :]
qk_bmm_res = np.matmul(q_tensor_cur, k_cur.transpose(0, 1, 3, 2), dtype=matmul_dtype)
qk_ele_res = qk_bmm_res * scalar
softmax_res, softmax_sum, softmax_max = gen_pa_softmax(qk_ele_res)
bmm2_res = np.matmul(softmax_res, v_cur, dtype=matmul_dtype) / softmax_sum
attent_out[b_index:(b_index + 1), :, :, :] = bmm2_res
attent_out = np.reshape(attent_out, (b * n_q * s_q, d_v))
q_nope = q_bnsd[:, :, :, : kv_lora_rank]
q_rope = q_bnsd[:, :, :, kv_lora_rank:]
k_cache_nope_h = kv_lora_rank * n_kv
k_cache_nope = k_cache[:, :, : k_cache_nope_h]
k_cache_rope = k_cache[:, :, k_cache_nope_h:]
k_cache_nope_nz = k_cache_nope.reshape(k_cache_nope.shape[0], k_cache_nope.shape[1], k_cache_nope.shape[2] // 16,
16)
k_cache_rope_nz = k_cache_rope.reshape(k_cache_rope.shape[0], k_cache_rope.shape[1], k_cache_rope.shape[2] // 16,
16)
v_cache_nz = v_cache.reshape(v_cache.shape[0], v_cache.shape[1], v_cache.shape[2] // 16, 16)
k_cache_nope_nz = np.transpose(k_cache_nope_nz, (0, 2, 1, 3))
k_cache_rope_nz = np.transpose(k_cache_rope_nz, (0, 2, 1, 3))
v_cache_nz = np.transpose(v_cache_nz, (0, 2, 1, 3))
input_params = np.array([b, s_q, n_q, n_kv, kv_lora_rank, qk_rope_dim, block_size, n_tile], dtype=np.int32)
q_nope_path = Path(output_dir, 'q_nope.bin')
q_nope.tofile(q_nope_path)
q_rope_path = Path(output_dir, 'q_rope.bin')
q_rope.tofile(q_rope_path)
k_cache_nope_path = Path(output_dir, 'k_cache_nope.bin')
k_cache_nope.tofile(k_cache_nope_path)
k_cache_rope_path = Path(output_dir, 'k_cache_rope.bin')
k_cache_rope.tofile(k_cache_rope_path)
v_cache_path = Path(output_dir, 'v_cache.bin')
v_cache.tofile(v_cache_path)
k_cache_nope_nz_path = Path(output_dir, 'k_cache_nope_nz.bin')
k_cache_nope_nz.tofile(k_cache_nope_nz_path )
kv_cache_nope_nz_path = Path(output_dir, 'kv_cache_nope_nz.bin')
k_cache_nope_nz.tofile(kv_cache_nope_nz_path)
k_cache_rope_nz_path = Path(output_dir, 'k_cache_rope_nz.bin')
k_cache_rope_nz.tofile(k_cache_rope_nz_path )
v_cache_nz_path = Path(output_dir, 'v_cache_nz.bin')
v_cache_nz.tofile(v_cache_nz_path)
block_table_path = Path(output_dir, 'block_table.bin')
block_table.tofile(block_table_path)
actual_seq_len_path = Path(output_dir, 'actual_seq_len.bin')
np.array(actual_seq_len).astype(np.int32).tofile(actual_seq_len_path)
attent_out_path = Path(output_dir, 'atten_out.bin')
attent_out.tofile(attent_out_path)
input_param_path = Path(output_dir, 'input_param.bin')
input_params.tofile(input_param_path)
return attent_out
def tensor_bf16_tofile(t: torch.Tensor, output: Path):
input_file_bin = open(str(output), "wb")
for each in t:
if t.dtype == torch.bfloat16:
input_file_bin.write(each.view(torch.int16).numpy().tobytes())
elif t.dtype == torch.float32:
input_file_bin.write(each.view(torch.int32).numpy().tobytes())
elif t.dtype == torch.int32:
input_file_bin.write(each.numpy().tobytes())
elif t.dtype == torch.int8:
input_file_bin.write(each.numpy().tobytes())
else:
raise ValueError(f"Unsupported dtype: {t.dtype}")
input_file_bin.close()
def gen_data_func_bf16(shape_size, dtype, case_name: str, output: Path):
input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim = shape_size
params_path = Path(output, 'params.bin')
input_path = Path(output, 'input.bin')
t1_path = Path(output, 't1.bin')
r1_path = Path(output, 'r1.bin')
t2_path = Path(output, 't2.bin')
w_uv_path = Path(output, 'w_uv.bin')
bmm4_path = Path(output, 'bmm4.bin')
t3_path = Path(output, 't3.bin')
r2_path = Path(output, 'r2.bin')
w_o_path = Path(output, 'w_o.bin')
bmm5_path = Path(output, 'bmm5.bin')
attn_output_path = Path(output, 'attn_output.bin')
complete = (params_path.exists() and input_path.exists() and t1_path.exists()
and r1_path.exists() and t2_path.exists() and w_uv_path.exists()
and bmm4_path.exists() and t3_path.exists() and r2_path.exists()
and w_o_path.exists() and attn_output_path.exists() and bmm5_path.exists())
if complete:
logging.debug("Case(%s), Golden complete.", case_name)
else:
dtype_num = 0
if dtype == torch.float32:
dtype_num = 0
elif dtype == torch.float16:
dtype_num = 1
elif dtype == torch.bfloat16:
dtype_num = 2
params = torch.tensor([input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim, dtype_num],
dtype=torch.int64)
input_t = torch.randn([input_b, input_n, input_s, kv_lora_rank], dtype=dtype)
w_uv = torch.randn([input_n, kv_lora_rank, v_head_dim], dtype=dtype)
w_o = torch.randn([input_n * v_head_dim, input_h], dtype=dtype)
params.numpy().tofile(params_path)
tensor_bf16_tofile(input_t, input_path)
tensor_bf16_tofile(w_uv, w_uv_path)
tensor_bf16_tofile(w_o, w_o_path)
t1 = input_t.transpose(1, 2)
tensor_bf16_tofile(t1, t1_path)
r1 = t1.reshape(input_b * input_s, input_n, kv_lora_rank)
tensor_bf16_tofile(r1, r1_path)
t2 = r1.transpose(0, 1)
tensor_bf16_tofile(t2, t2_path)
calc_input = t2
bmm4 = torch.matmul(calc_input.to(torch.float32), w_uv.to(torch.float32))
if dtype != torch.float32:
bmm4 = bmm4.to(dtype)
tensor_bf16_tofile(bmm4, bmm4_path)
t3 = bmm4.transpose(0, 1)
tensor_bf16_tofile(t3, t3_path)
r2 = t3.reshape(input_b * input_s, input_n * v_head_dim)
tensor_bf16_tofile(r2, r2_path)
bmm5_i = r2
bmm5 = torch.matmul(bmm5_i.to(torch.float32), w_o.to(torch.float32))
if dtype != torch.float32:
bmm5 = bmm5.to(dtype)
tensor_bf16_tofile(bmm5, bmm5_path)
bmm5 = bmm5.reshape(input_b, input_s, input_h)
tensor_bf16_tofile(bmm5, attn_output_path)
return True
def quantize_torch(input_fp32):
abs_res = torch.abs(input_fp32)
max_value, _ = torch.max(abs_res, dim=-1, keepdim=True)
scale_quant = 127.0 / max_value
out_fp32 = input_fp32 * scale_quant
out_int32 = torch.round(out_fp32).to(torch.int32)
out_int8 = torch.clamp(out_int32, -128, 127).to(torch.int8)
scale_dequant = 1.0 / scale_quant
return out_int8, scale_dequant
def gen_quant_mm_torch(a, w, scale_w, output: Path):
a_fp32 = a.to(torch.float32)
quantized_a, scale_dequant_a = quantize_torch(a_fp32)
quant1_int8_path = Path(output, 'quant0_int8.bin')
quant1_fp32_path = Path(output, 'quant0_fp32.bin')
tensor_bf16_tofile(quantized_a, quant1_int8_path)
tensor_bf16_tofile(scale_dequant_a, quant1_fp32_path)
a_int32 = quantized_a.to(torch.int32)
w_int32 = w.to(torch.int32)
res_int32 = torch.matmul(a_int32, w_int32)
mm5_int32_path = Path(output, 'mm5_int32.bin')
tensor_bf16_tofile(res_int32, mm5_int32_path)
res = res_int32.to(torch.float32)
mm5_fp32_path = Path(output, 'mm5_fp32.bin')
tensor_bf16_tofile(res, mm5_fp32_path)
res = res * scale_dequant_a
res = res * scale_w
return res.to(a.dtype)
def gen_data_func_bf16_quant_dynamic(shape_size, dtype, case_name: str, output: Path):
input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim = shape_size
params_path = Path(output, 'params.bin')
input_path = Path(output, 'input.bin')
t1_path = Path(output, 't1.bin')
r1_path = Path(output, 'r1.bin')
cast1_path = Path(output, 'cast1.bin')
w_uv_path = Path(output, 'w_uv.bin')
w_uv_scale_w_path = Path(output, 'w_uv_scale_w.bin')
cast0_out_path = Path(output, 'cast0_out.bin')
abs_out_path = Path(output, 'abs_out.bin')
mul0_out_path = Path(output, 'mul0_out.bin')
rms_out_path = Path(output, 'rms_out.bin')
quant1_int8_path = Path(output, 'quant0_int8.bin')
quant1_fp32_path = Path(output, 'quant0_fp32.bin')
bmm4_int32_path = Path(output, 'bmm4_int32.bin')
bmm4_path = Path(output, 'bmm4.bin')
t3_path = Path(output, 't3.bin')
r2_path = Path(output, 'r2.bin')
w_o_path = Path(output, 'w_o.bin')
w_o_scale_w_path = Path(output, 'w_o_scale_w.bin')
bmm5_path = Path(output, 'bmm5.bin')
attn_output_path = Path(output, 'attn_output.bin')
complete = (params_path.exists() and input_path.exists() and t1_path.exists() and r1_path.exists()
and cast1_path.exists() and w_uv_path.exists() and bmm4_path.exists() and t3_path.exists()
and r2_path.exists() and w_o_path.exists() and attn_output_path.exists() and bmm5_path.exists()
and w_uv_scale_w_path.exists() and w_o_scale_w_path.exists())
complete = False
if complete:
logging.debug("Case(%s), Golden complete.", case_name)
else:
dtype_num = 0
if dtype == torch.float32:
dtype_num = 0
elif dtype == torch.float16:
dtype_num = 1
elif dtype == torch.bfloat16:
dtype_num = 2
params = torch.tensor([input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim, dtype_num],
dtype=torch.int64)
input_t = torch.randn([input_b * input_n * input_s, kv_lora_rank], dtype=torch.float32)
w_uv = torch.randn([input_n, kv_lora_rank, v_head_dim], dtype=dtype)
w_uv_scale_w = torch.randn([input_n, 1, v_head_dim], dtype=torch.float32) * 0.001
w_o = torch.randint(size=(input_n * v_head_dim, input_h), low=-128, high=128, dtype=torch.int8)
w_o_scale_w = torch.randn([input_h], dtype=torch.float32) * 0.001
params.numpy().tofile(params_path)
tensor_bf16_tofile(input_t, input_path)
tensor_bf16_tofile(w_uv, w_uv_path)
tensor_bf16_tofile(w_uv_scale_w, w_uv_scale_w_path)
w_o_nz = w_o.reshape(input_n * v_head_dim, input_h // 32, 32)
w_o_nz = w_o_nz.transpose(0, 1)
tensor_bf16_tofile(w_o_nz, w_o_path)
tensor_bf16_tofile(w_o_scale_w, w_o_scale_w_path)
r1 = input_t.reshape(input_b * input_s, input_n, kv_lora_rank)
tensor_bf16_tofile(r1, r1_path)
cast1 = r1.to(dtype)
tensor_bf16_tofile(cast1, cast1_path)
t1 = cast1.transpose(0, 1)
tensor_bf16_tofile(t1, t1_path)
calc_input = t1
bmm4 = torch.matmul(calc_input.to(torch.float32), w_uv.to(torch.float32))
if dtype != torch.float32:
bmm4 = bmm4.to(dtype)
tensor_bf16_tofile(bmm4, bmm4_path)
t3 = bmm4.transpose(0, 1)
tensor_bf16_tofile(t3, t3_path)
r2 = t3.reshape(input_b * input_s, input_n * v_head_dim)
tensor_bf16_tofile(r2, r2_path)
bmm5_i = r2
bmm5 = gen_quant_mm_torch(bmm5_i, w_o, w_o_scale_w, output)
tensor_bf16_tofile(bmm5, bmm5_path)
bmm5 = bmm5.reshape(input_b, input_s, input_h)
tensor_bf16_tofile(bmm5, attn_output_path)
return True
def gen_data_func_bf16_quant_dynamic_cast(shape_size, dtype, case_name: str, output: Path):
input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim = shape_size
params_path = Path(output, 'params.bin')
input_path = Path(output, 'input.bin')
t1_path = Path(output, 't1.bin')
r1_path = Path(output, 'r1.bin')
cast1_path = Path(output, 'cast1.bin')
w_uv_path = Path(output, 'w_uv.bin')
w_uv_scale_w_path = Path(output, 'w_uv_scale_w.bin')
cast0_out_path = Path(output, 'cast0_out.bin')
abs_out_path = Path(output, 'abs_out.bin')
mul0_out_path = Path(output, 'mul0_out.bin')
rms_out_path = Path(output, 'rms_out.bin')
quant1_int8_path = Path(output, 'quant0_int8.bin')
quant1_fp32_path = Path(output, 'quant0_fp32.bin')
bmm4_int32_path = Path(output, 'bmm4_int32.bin')
bmm4_path = Path(output, 'bmm4.bin')
t3_path = Path(output, 't3.bin')
r2_path = Path(output, 'r2.bin')
w_o_path = Path(output, 'w_o.bin')
w_o_nd_path = Path(output, 'w_o_nd.bin')
w_o_scale_w_path = Path(output, 'w_o_scale_w.bin')
bmm5_path = Path(output, 'bmm5.bin')
attn_output_path = Path(output, 'attn_output.bin')
complete = (params_path.exists() and input_path.exists()
and t1_path.exists() and r1_path.exists()
and cast1_path.exists() and w_uv_path.exists()
and bmm4_path.exists() and t3_path.exists()
and r2_path.exists() and w_o_path.exists()
and attn_output_path.exists() and bmm5_path.exists()
and w_uv_scale_w_path.exists() and w_o_scale_w_path.exists())
complete = False
if complete:
logging.debug("Case(%s), Golden complete.", case_name)
else:
dtype_num = 0
if dtype == torch.float32:
dtype_num = 0
elif dtype == torch.float16:
dtype_num = 1
elif dtype == torch.bfloat16:
dtype_num = 2
params = torch.tensor([input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim, dtype_num],
dtype=torch.int64)
input_t = torch.randn([input_b * input_n * input_s, kv_lora_rank], dtype=torch.float32)
w_uv = torch.randn([input_n, kv_lora_rank, v_head_dim], dtype=dtype)
w_uv = torch.randn([input_n, kv_lora_rank, v_head_dim], dtype=dtype)
w_uv_scale_w = torch.randn([input_n, 1, v_head_dim], dtype=torch.float32) * 0.001
w_o = torch.randint(size=(input_n * v_head_dim, input_h), low=-128, high=128, dtype=torch.int8)
w_o_scale_w = torch.randn([input_h], dtype=torch.float32) * 0.001
params.numpy().tofile(params_path)
tensor_bf16_tofile(input_t, input_path)
tensor_bf16_tofile(w_uv, w_uv_path)
tensor_bf16_tofile(w_uv_scale_w, w_uv_scale_w_path)
w_o_nz = w_o.reshape(input_n * v_head_dim, input_h // 32, 32)
w_o_nz = w_o_nz.transpose(0, 1)
tensor_bf16_tofile(w_o_nz, w_o_path)
tensor_bf16_tofile(w_o, w_o_nd_path)
tensor_bf16_tofile(w_o_scale_w, w_o_scale_w_path)
cast1 = input_t.to(dtype)
tensor_bf16_tofile(cast1, cast1_path)
r1 = cast1.reshape(input_b * input_s, input_n, kv_lora_rank)
tensor_bf16_tofile(r1, r1_path)
t1 = r1.transpose(0, 1)
tensor_bf16_tofile(t1, t1_path)
calc_input = t1
bmm4 = torch.matmul(calc_input.to(torch.float32), w_uv.to(torch.float32))
if dtype != torch.float32:
bmm4 = bmm4.to(dtype)
tensor_bf16_tofile(bmm4, bmm4_path)
t3 = bmm4.transpose(0, 1)
tensor_bf16_tofile(t3, t3_path)
r2 = t3.reshape(input_b * input_s, input_n * v_head_dim)
tensor_bf16_tofile(r2, r2_path)
bmm5_i = r2
bmm5 = gen_quant_mm_torch(bmm5_i, w_o, w_o_scale_w, output)
tensor_bf16_tofile(bmm5, bmm5_path)
bmm5 = bmm5.reshape(input_b, input_s, input_h)
tensor_bf16_tofile(bmm5, attn_output_path)
return True
def gen_data_func_bf16_quant_onlymm5(shape_size, dtype, case_name: str, output: Path):
input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim = shape_size
params_path = Path(output, 'params.bin')
input_path = Path(output, 'input.bin')
t1_path = Path(output, 't1.bin')
r1_path = Path(output, 'r1.bin')
t2_path = Path(output, 't2.bin')
w_uv_path = Path(output, 'w_uv.bin')
w_uv_scale_w_path = Path(output, 'w_uv_scale_w.bin')
cast0_out_path = Path(output, 'cast0_out.bin')
abs_out_path = Path(output, 'abs_out.bin')
mul0_out_path = Path(output, 'mul0_out.bin')
rms_out_path = Path(output, 'rms_out.bin')
quant1_int8_path = Path(output, 'quant0_int8.bin')
quant1_fp32_path = Path(output, 'quant0_fp32.bin')
bmm4_int32_path = Path(output, 'bmm4_int32.bin')
bmm4_path = Path(output, 'bmm4.bin')
t3_path = Path(output, 't3.bin')
r2_path = Path(output, 'r2.bin')
w_o_path = Path(output, 'w_o.bin')
w_o_scale_w_path = Path(output, 'w_o_scale_w.bin')
bmm5_path = Path(output, 'bmm5.bin')
attn_output_path = Path(output, 'attn_output.bin')
complete = (params_path.exists() and input_path.exists() and t1_path.exists() and r1_path.exists()
and t2_path.exists() and w_uv_path.exists() and bmm4_path.exists() and t3_path.exists()
and r2_path.exists() and w_o_path.exists() and attn_output_path.exists() and bmm5_path.exists()
and w_uv_scale_w_path.exists() and w_o_scale_w_path.exists())
complete = False
if complete:
logging.debug("Case(%s), Golden complete.", case_name)
else:
dtype_num = 0
if dtype == torch.float32:
dtype_num = 0
elif dtype == torch.float16:
dtype_num = 1
elif dtype == torch.bfloat16:
dtype_num = 2
params = torch.tensor([input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim, dtype_num],
dtype=torch.int64)
input_t = torch.randn([input_b, input_n, input_s, kv_lora_rank], dtype=dtype)
w_uv = torch.randn([input_n, kv_lora_rank, v_head_dim], dtype=dtype)
w_uv_scale_w = torch.randn([input_n, 1, v_head_dim], dtype=torch.float32) * 0.001
w_o = torch.randint(size=(input_n * v_head_dim, input_h), low=-128, high=128, dtype=torch.int8)
w_o_scale_w = torch.randn([input_h], dtype=torch.float32) * 0.001
params.numpy().tofile(params_path)
tensor_bf16_tofile(input_t, input_path)
tensor_bf16_tofile(w_uv, w_uv_path)
tensor_bf16_tofile(w_uv_scale_w, w_uv_scale_w_path)
w_o_nz = w_o.reshape(input_n * v_head_dim, input_h // 32, 32)
w_o_nz = w_o_nz.transpose(0, 1)
tensor_bf16_tofile(w_o_nz, w_o_path)
tensor_bf16_tofile(w_o_scale_w, w_o_scale_w_path)
t1 = input_t.transpose(1, 2)
tensor_bf16_tofile(t1, t1_path)
r1 = t1.reshape(input_b * input_s, input_n, kv_lora_rank)
tensor_bf16_tofile(r1, r1_path)
t2 = r1.transpose(0, 1)
tensor_bf16_tofile(t2, t2_path)
calc_input = t2
bmm4 = torch.matmul(calc_input.to(torch.float32), w_uv.to(torch.float32))
if dtype != torch.float32:
bmm4 = bmm4.to(dtype)
tensor_bf16_tofile(bmm4, bmm4_path)
t3 = bmm4.transpose(0, 1)
tensor_bf16_tofile(t3, t3_path)
r2 = t3.reshape(input_b * input_s, input_n * v_head_dim)
tensor_bf16_tofile(r2, r2_path)
bmm5_i = r2
bmm5 = gen_quant_mm_torch(bmm5_i, w_o, w_o_scale_w, output)
tensor_bf16_tofile(bmm5, bmm5_path)
bmm5 = bmm5.reshape(input_b, input_s, input_h)
tensor_bf16_tofile(bmm5, attn_output_path)
return True
def attention_post_func_torch_bf16(case_name: str, output: Path):
if case_name == "OnBoardTest.test_attention_post_bf16_real_batch4":
input_b = 4
input_s = 1
input_n = 32
input_h = 7168
kv_lora_rank = 512
v_head_dim = 128
dtype = torch.bfloat16
gen_data_func_bf16((input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim), dtype, case_name, output)
elif case_name == "OnBoardTest.test_attention_post_bf16_real_n128":
input_b = 32
input_s = 1
input_n = 128
input_h = 7168
kv_lora_rank = 512
v_head_dim = 128
dtype = torch.bfloat16
gen_data_func_bf16((input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim), dtype, case_name, output)
def attention_post_func_torch_mm5(case_name: str, output: Path):
if (case_name == "OnBoardCostTest.test_attention_post_bf16_real_quant_n128_onlymm5"
or case_name == "OnBoardCostTest.test_attention_post_bf16_real_quant_n128_onlymm5K"
or case_name == "OnBoardCostTest.test_attention_post_bf16_real_quant_n128_onlymm5He"):
input_b = 32
input_s = 1
input_n = 128
input_h = 7168
kv_lora_rank = 512
v_head_dim = 128
dtype = torch.bfloat16
gen_data_func_bf16_quant_onlymm5((input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim), dtype,
case_name, output)
elif (case_name == "OnBoardCostTest.test_attention_post_bf16_real_quant_batch4_onlymm5"
or case_name == "OnBoardCostTest.test_attention_post_bf16_real_quant_batch4_onlymm5K"):
input_b = 4
input_s = 1
input_n = 32
input_h = 7168
kv_lora_rank = 512
v_head_dim = 128
dtype = torch.bfloat16
gen_data_func_bf16_quant_onlymm5((input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim), dtype,
case_name, output)
def attention_post_func_numpy_b32(case_name: str, output: Path):
if (case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_splitk"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_normal_unsplitk"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_splitk"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_r2"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_t3r2"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4tr"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_r1"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlyt1"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_t1"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlybmm4_fail"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlybmm4"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_bmm4"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_t3"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_quant"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4tr_quant_fail"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4tr_quant"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4trq_mm5nd"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4trq_mm5ndk"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_unquant_r3"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_nd"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_nz"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_ndk"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_mm5ndk_unquant_r3"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_nzk"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_unsplitk"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_unsplitk"):
input_b = 32
input_s = 1
input_n = 128
input_h = 7168
kv_lora_rank = 512
v_head_dim = 128
dtype = bfloat16
params = {
"b": input_b,
"s": input_s,
"h": input_h,
"num_heads": input_n,
"kv_lora_rank": kv_lora_rank,
"v_head_dim": v_head_dim,
}
gen_post_data(output, params, dtype, None, True, True)
def attention_post_func_numpy_b4(case_name: str, output: Path):
if (case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_splitk_low"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_splitk_low"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_unsplitk_low"
or case_name == "DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_unsplitk_low"):
input_b = 4
input_s = 1
input_n = 32
input_hh = 7168
kv_lora_rank1 = 512
v_head_dim1 = 128
dtype1 = bfloat16
params1 = {
"b": input_b,
"s": input_s,
"h": input_hh,
"num_heads": input_n,
"kv_lora_rank": kv_lora_rank1,
"v_head_dim": v_head_dim1,
}
gen_post_data(output, params1, dtype1, None, True, True)
@GoldenRegister.reg_golden_func(
case_names=[
"OnBoardTest.test_attention_post_bf16_real_batch4",
"OnBoardTest.test_attention_post_bf16_real_n128",
"OnBoardCostTest.test_attention_post_bf16_real_quant_n128_onlymm5",
"OnBoardCostTest.test_attention_post_bf16_real_quant_batch4_onlymm5",
"OnBoardCostTest.test_attention_post_bf16_real_quant_n128_onlymm5K",
"OnBoardCostTest.test_attention_post_bf16_real_quant_batch4_onlymm5K",
"OnBoardCostTest.test_attention_post_bf16_real_quant_n128_onlymm5He",
"OnBoardCostTest.dynamic_pa_post_static_cast_first",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_splitk",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_normal_unsplitk",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_splitk",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_r2",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_t3r2",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4tr",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_r1",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlyt1",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_t1",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlybmm4_fail",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlybmm4",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_bmm4",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_t3",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_quant",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4tr_quant_fail",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4tr_quant",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4trq_mm5nd",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_crtb4trq_mm5ndk",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_unquant_r3",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_nd",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_nz",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_ndk",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_mm5ndk_unquant_r3",
"DynamicAttentionPostTest.dynamic_pa_post_cast_first_onlymm5_nzk",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_unsplitk",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_unsplitk",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_splitk_low",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_splitk_low",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nz_unsplitk_low",
"DynamicAttentionPostTest.dynamic_pa_post_new_mm5nd_unsplitk_low",
"DynamicAttentionPostTest.dynamic_pa_papost_bf16_b48",
]
)
def attention_post_func(case_name: str, output: Path) -> bool:
attention_post_func_torch_bf16(case_name, output)
attention_post_func_torch_mm5(case_name, output)
attention_post_func_numpy_b32(case_name, output)
attention_post_func_numpy_b4(case_name, output)
if case_name == "OnBoardCostTest.dynamic_pa_post_static_cast_first":
input_b = 32
input_s = 1
input_n = 128
input_h = 7168
kv_lora_rank = 512
v_head_dim = 128
dtype = torch.bfloat16
gen_data_func_bf16_quant_dynamic_cast((input_b, input_s, input_n, input_h, kv_lora_rank, v_head_dim),
dtype, case_name, output)
elif (case_name == "DynamicAttentionPostTest.dynamic_pa_papost_bf16_b48"):
b = 48
n_q = 128
skv = 4096
block_size = 4096
n_tile = 128
s_q = 1
n_kv = 1
kv_lora_rank = 512
qk_rope_dim = 64
dtype = bfloat16
input_h = 7168
v_head_dim = 128
params = {
"b": b,
"s": s_q,
"s2": skv,
"num_heads": n_q,
"n2": n_kv,
"qk_rope_head_dim": qk_rope_dim,
"kv_lora_rank": kv_lora_rank,
"h": input_h,
"v_head_dim": v_head_dim,
}
pa_out = gen_pa_data(output, params, dtype, None, None, None, None, block_size, n_tile, False)
gen_post_data(output, params, dtype, pa_out, True, True)
else:
return True
return True
def main() -> bool:
"""
单独调试 入口函数
"""
case_name_list: List[str] = [
"OnBoardTest.test_attention_post_bf16_real_n128",
]
ret: bool = True
for cs in case_name_list:
output: Path = Path(g_src_root, "build/output/bin/golden", cs).resolve()
output.mkdir(parents=True, exist_ok=True)
ret = attention_post_func(case_name=cs, output=output)
return ret
if __name__ == "__main__":
exit(0 if main() else 1)