"""
本脚本有 2 种执行模式:
CI批跑时, 由 cmake/scripts/golden_ctrl.py 调用, 为避免日志过多, 此时 logging 级别为 logging.INFO;
单独调试时, 本脚本单独被调用, 此时 logging 级别为 logging.DEBUG;
"""
import math
import sys
import logging
from pathlib import Path
from typing import List
import numpy as np
from ml_dtypes import bfloat16
if __name__ == "__main__":
    # Standalone debugging setup: verbose logging and make cmake/scripts
    # (where golden_register lives) importable. In CI this module is imported
    # by cmake/scripts/golden_ctrl.py, which already provides both.
    logging.basicConfig(format='%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s',
                        level=logging.DEBUG)
    # Repository root, five levels above this script.
    g_src_root: Path = Path(Path(__file__).parent, "../../../../../").resolve()
    logging.debug("SrcRoot: %s", g_src_root)
    g_ctrl_path: Path = Path(g_src_root, "cmake/scripts")
    if str(g_ctrl_path) not in sys.path:
        sys.path.append(str(g_ctrl_path))
from golden_register import GoldenRegister  # noqa: E402  (needs sys.path set up above)
def dump_file(data_pool, data_path, type_str):
    """Cast ``data_pool`` to the dtype named by ``type_str`` and write it as raw binary.

    Args:
        data_pool: Array-like data (list, nested list or ndarray) to write.
        data_path: Destination path, passed to ``ndarray.tofile`` (no header,
            C order).
        type_str: Case-insensitive dtype name: 'fp16', 'fp32', 'fp64', 'bf16',
            'int8'..'int64', 'uint8'..'uint64', 'complex64', 'complex128', 'bool'.

    Raises:
        ValueError: If ``type_str`` is not one of the supported names, so a
            missing golden file is reported instead of silently skipped.
    """
    type_key = type_str.lower()
    if type_key == 'bf16':
        # bfloat16 comes from ml_dtypes; only touched when actually requested.
        np_dtype = bfloat16
    else:
        np_dtype = {
            'fp16': np.float16,
            'fp32': np.float32,
            'fp64': np.float64,
            'int8': np.int8,
            'int16': np.int16,
            'int32': np.int32,
            'int64': np.int64,
            'uint8': np.uint8,
            'uint16': np.uint16,
            'uint32': np.uint32,
            'uint64': np.uint64,
            'complex64': np.complex64,
            'complex128': np.complex128,
            'bool': np.bool_,
        }.get(type_key)
        if np_dtype is None:
            raise ValueError(f"unsupported type_str for dump_file: {type_str!r}")
    np.array(data_pool).astype(np_dtype).tofile(data_path)
def gen_uniform_data(data_shape, min_value, max_value, dtype):
    """Return random test data of the given shape and dtype.

    - If both bounds are 0, returns an all-zero array of ``dtype``.
    - For ``np.bool_``, returns a random True/False array (bounds ignored).
    - Otherwise samples uniformly from [min_value, max_value) and casts to
      ``dtype`` (integer dtypes therefore truncate the samples).
    """
    all_zero_range = min_value == 0 and max_value == 0
    if all_zero_range:
        return np.zeros(data_shape, dtype=dtype)
    wants_bool = dtype == np.bool_
    if not wants_bool:
        samples = np.random.uniform(low=min_value, high=max_value, size=data_shape)
        return samples.astype(dtype)
    return np.random.choice([True, False], size=data_shape)
def numpy_topk(input_array, k, axis=-1):
    """Return the k largest values along ``axis`` and their indices, like torch.topk.

    Args:
        input_array (np.ndarray): Input array (anything ``np.asarray`` accepts).
        k (int): Number of largest values to extract; 1 <= k <= size of ``axis``.
        axis (int): Axis to operate on; defaults to the last axis.

    Returns:
        values (np.ndarray): The k largest values, sorted descending along ``axis``.
        indices (np.ndarray): Their indices along ``axis`` in ``input_array``.

    Raises:
        ValueError: If k <= 0 or k exceeds the length of ``axis``.
    """
    if k <= 0:
        raise ValueError("k必须为正整数")
    input_array = np.asarray(input_array)
    axis_len = input_array.shape[axis]
    if k > axis_len:
        raise ValueError(f"k ({k}) exceeds the size of axis {axis} ({axis_len})")
    # After partitioning, the k largest elements occupy the last k slots of `axis`.
    partitioned = np.argpartition(input_array, axis_len - k, axis=axis)
    # Slice those slots along the requested axis (not always the last one).
    top_indices = np.take(partitioned, np.arange(axis_len - k, axis_len), axis=axis)
    top_values = np.take_along_axis(input_array, top_indices, axis=axis)
    # Ascending sort then flip gives descending order without negating the
    # values, which would wrap for unsigned dtypes.
    order = np.flip(np.argsort(top_values, axis=axis, kind='stable'), axis=axis)
    final_indices = np.take_along_axis(top_indices, order, axis=axis)
    final_values = np.take_along_axis(input_array, final_indices, axis=axis)
    return final_values, final_indices
def kv_slc_compute(compute_input_params, topk_indecies, topk_tensor_shape, kvNopeCache, krCache, block_table, actual_seq_len):
    """Golden model: gather the selected KV chunks out of the paged caches.

    For every (batch, query position, KV head, top-k slot) a chunk of
    ``l_prime`` consecutive cache rows is located through ``block_table`` and
    copied into the outputs.

    Args:
        compute_input_params: [block_size, n2, front, near, topK, l_prime].
        topk_indecies: chunk indices for the middle slots, read as
            topk_indecies[batch][seq][topKIdx - front]; shape[0] is b and
            shape[1] is s.
        topk_tensor_shape: topk_tensor_shape[batch][seq] is the s_slc value
            used to place the trailing ``near`` chunks.
        kvNopeCache: 2-D cache whose rows are addressed as
            block_id * block_size + offset; shape[1] is taken as kv_lora_rank.
        krCache: 2-D cache with the same row addressing; shape[1] is taken as
            rope_dim.
        block_table: block_table[batch][i] is the physical block id holding
            logical block i of that batch.
        actual_seq_len: per-batch valid token count; only used to compute
            kv_slc_actual_seqs.

    Returns:
        k_slc_out: [b * n2 * s * topK * l_prime, rope_dim + kv_lora_rank];
            the first kv_lora_rank columns come from kvNopeCache, the rest
            from krCache.
        v_slc_out: [b * n2 * s * topK * l_prime, kv_lora_rank] copied from
            kvNopeCache.
        kv_slc_actual_seqs: int32 [b, s]; number of gathered rows lying below
            actual_seq_len[batch].
    """
    block_size = compute_input_params[0]
    n2 = compute_input_params[1]
    front = compute_input_params[2]
    near = compute_input_params[3]
    topK = compute_input_params[4]
    l_prime = compute_input_params[5]
    b = topk_indecies.shape[0]
    s = topk_indecies.shape[1]
    rope_dim = krCache.shape[1]
    kv_lora_rank = kvNopeCache.shape[1]
    kv_cache_axis1 = kvNopeCache.shape[0]  # NOTE(review): unused
    shape_k_slc_out = [b * n2 * s * topK * l_prime, rope_dim + kv_lora_rank]
    shape_v_slc_out = [b * n2 * s * topK * l_prime, kv_lora_rank]
    k_slc_out = np.zeros(shape_k_slc_out, kvNopeCache.dtype)
    v_slc_out = np.zeros(shape_v_slc_out, kvNopeCache.dtype)
    kv_slc_actual_seqs = np.zeros([b, s], dtype=np.int32)
    for batchIdx in range(b):
        for seqIdx in range(s):
            # Reset per (batch, seq) only: with n2 > 1 the count accumulates
            # over all KV heads.
            slcSeqLen = 0
            s_slc = topk_tensor_shape[batchIdx][seqIdx]
            for nkvIdx in range(n2):
                for topKIdx in range(topK):
                    # Slot -> chunk index ("position"):
                    #   first `front` slots  -> chunks 0 .. front-1
                    #   trailing slots       -> the last chunks ending at s_slc
                    #   remaining slots      -> taken from topk_indecies
                    # NOTE(review): the trailing-region test and offset only
                    # reduce to `topKIdx >= topK - near` and
                    # `topKIdx - (topK - near)` when front == 1; for other
                    # front values regions overlap/shift - confirm vs kernel.
                    if topKIdx < front:
                        position = topKIdx
                    elif topKIdx > topK - near - front:
                        position = s_slc - near + (topKIdx - (topK - front - near) - 1)
                    else:
                        position = topk_indecies[batchIdx][seqIdx][topKIdx - front]
                    # The chunk starts at token position * l_prime: split that
                    # into a logical block and an offset inside the block.
                    block_idx_in_batch = int(position * l_prime / block_size)
                    tail = int(position * l_prime % block_size)
                    slcBlockIdx = block_table[batchIdx][block_idx_in_batch]
                    # Valid tokens in this chunk:
                    # clip(actual_seq_len - position * l_prime, 0, l_prime).
                    slcSeqLen = slcSeqLen + max(l_prime - max(position * l_prime + l_prime - actual_seq_len[batchIdx], 0), 0)
                    # Output rows are ordered [batch][seq][head][slot][l_prime].
                    preIdx_out_base = batchIdx * s * n2 * topK * l_prime + seqIdx * n2 * topK * l_prime + nkvIdx * topK * l_prime + topKIdx * l_prime
                    # NOTE(review): l_prime contiguous rows are read from here;
                    # assumes tail + l_prime <= block_size, otherwise rows of the
                    # next *physical* block are read.
                    preIdx_cache_base = slcBlockIdx * block_size + tail
                    # NOTE(review): the column ranges below do not depend on
                    # nkvIdx, so every head copies the same columns - confirm
                    # this is intended for n2 > 1.
                    k_slc_out[preIdx_out_base : preIdx_out_base + l_prime, 0:kv_lora_rank] = kvNopeCache[preIdx_cache_base : preIdx_cache_base + l_prime, 0:kv_lora_rank]
                    k_slc_out[preIdx_out_base : preIdx_out_base + l_prime, kv_lora_rank:kv_lora_rank + rope_dim] = krCache[preIdx_cache_base : preIdx_cache_base + l_prime, 0:rope_dim]
                    v_slc_out[preIdx_out_base : preIdx_out_base + l_prime, 0:kv_lora_rank] = kvNopeCache[preIdx_cache_base : preIdx_cache_base + l_prime, 0:kv_lora_rank]
            kv_slc_actual_seqs[batchIdx][seqIdx] = slcSeqLen
    return k_slc_out, v_slc_out, kv_slc_actual_seqs
def gen_block_table(b, actual_seq_len, block_size, output: Path):
    """Build a random paged-attention block table and dump it to block_table.bin.

    Each batch needs ceil(actual_seq_len[i] / block_size) blocks. The pool of
    physical block ids 0..block_num-1 is shuffled once, and each batch takes
    the next run of ids from it, so no two batches share a physical block.
    Unused table entries are -1.

    Args:
        b: Number of rows (batches) in the table.
        actual_seq_len: Per-batch sequence lengths.
        block_size: Tokens per block.
        output: Directory in which block_table.bin is written (int32).

    Returns:
        (block_num, block_table): total number of physical blocks, and an
        int32 array of shape [b, ceil(max(actual_seq_len) / block_size)].
    """
    block_num_per_batch = [math.ceil(actual_seq / block_size) for actual_seq in actual_seq_len]
    block_num = sum(block_num_per_batch)
    max_blocks_per_batch = math.ceil(max(actual_seq_len) / block_size)
    block_idx_list = np.random.permutation(np.arange(0, block_num, 1)).astype(np.int32)
    block_table = np.full((b, max_blocks_per_batch), -1, dtype=np.int32)
    # A single running cursor over the shuffled pool gives each batch its own
    # blocks (resetting it per batch would make all batches alias the same ones).
    next_block = 0
    for batch_idx, batch_block_num in enumerate(block_num_per_batch):
        block_table[batch_idx, :batch_block_num] = block_idx_list[next_block:next_block + batch_block_num]
        next_block += batch_block_num
    logging.debug("block_table %s", block_table)
    block_table_path = Path(output, 'block_table.bin')
    dump_file(block_table, block_table_path, "int32")
    return block_num, block_table
def gen_i_o_tensor(input_param, s_slc, s2, dtype, output: Path):
    """Generate random inputs for one KV-selection case, run the golden model and dump all tensors.

    Args:
        input_param: [b, s, n2, kv_lora_rank, rope_dim, front, near, topK,
            l_prime, block_size].
        s_slc: Value written to every entry of topk_tensor_shape; also the
            exclusive upper bound when sampling the top-k chunk indices.
        s2: Actual sequence length used for every batch.
        dtype: Element type of the KV caches; bfloat16 or np.float16.
        output: Directory the .bin files are written to.

    Raises:
        ValueError: If ``dtype`` is neither bfloat16 nor np.float16 (checked
            before any file is written).
    """
    b, s, n2, kv_lora_rank, rope_dim, front, near, topK, l_prime, block_size = input_param[:10]
    if dtype == bfloat16:
        dump_dtype = "bf16"
    elif dtype == np.float16:
        dump_dtype = "fp16"
    else:
        raise ValueError(f"unsupported dtype for golden data: {dtype}")

    actual_seq_len = [s2] * b
    dump_file(actual_seq_len, Path(output, 'actual_seq_len.bin'), "int32")
    block_num, block_table = gen_block_table(b, actual_seq_len, block_size, output)

    shape_topk_indecies = [b, s, topK - front - near]
    shape_kv_nope_cache = [block_num * block_size, n2 * kv_lora_rank]
    shape_kr_cache = [block_num * block_size, n2 * rope_dim]
    topk_indecies = gen_uniform_data(shape_topk_indecies, 0, s_slc, dtype=np.int32)
    topk_tensor_shape = np.full([b, s], s_slc, dtype=np.int32)
    kv_nope_cache = gen_uniform_data(shape_kv_nope_cache, -1, 1, dtype)
    kr_cache = gen_uniform_data(shape_kr_cache, -1, 1, dtype)

    dump_file(topk_indecies, Path(output, 'topk_tensor.bin'), "int32")
    dump_file(topk_tensor_shape, Path(output, 'topk_tensor_shape.bin'), "int32")
    dump_file(kv_nope_cache, Path(output, 'kv_nope_cache.bin'), dump_dtype)
    dump_file(kr_cache, Path(output, 'k_rope_cache.bin'), dump_dtype)

    compute_input_params = [block_size, n2, front, near, topK, l_prime]
    k_slc_out, v_slc_out, kv_slc_actual_seqs = kv_slc_compute(
        compute_input_params, topk_indecies, topk_tensor_shape,
        kv_nope_cache, kr_cache, block_table, actual_seq_len)
    dump_file(k_slc_out, Path(output, 'k_slc_out.bin'), dump_dtype)
    dump_file(v_slc_out, Path(output, 'v_slc_out.bin'), dump_dtype)
    dump_file(kv_slc_actual_seqs, Path(output, 'kv_slc_actual_seqs.bin'), "int32")
@GoldenRegister.reg_golden_func(
    case_names=[
        "DynamicSlcTest.dynamic_p_slc_fp16",
        "DynamicSlcTest.dynamic_p_slc_bf16",
    ]
)
def kv_slc_func(case_name: str, output: Path) -> bool:
    """Generate golden input/output data for one DynamicSlcTest case.

    Writes input_param.bin (int32 [b, s, n2, kv_lora_rank, rope_dim, front,
    near, topK, l_prime, block_size]) and then all tensors produced by
    gen_i_o_tensor into ``output``.

    Args:
        case_name: Registered test case name.
        output: Directory the golden files are written to.

    Returns:
        True when the data was generated; False (with an error log) when
        ``case_name`` matches no known case.
    """
    # Parameters shared by every case.
    block_size = 128
    s = 1
    n2 = 1
    kv_lora_rank = 512
    rope_dim = 64
    front = 1
    near = 2
    topK = 16
    l_prime = 64
    # Per-case parameters: batch size, s_slc, actual KV length, cache dtype.
    if case_name.startswith('DynamicSlcTest.dynamic_p_slc_fp16'):
        b, s_slc, s2, dtype = 32, 128, 8192, np.float16
    elif case_name.startswith('DynamicSlcTest.dynamic_p_slc_bf16'):
        b, s_slc, s2, dtype = 16, 32, 4096, bfloat16
    else:
        logging.error("Unsupported case: %s", case_name)
        return False
    golden_input_params = [b, s, n2, kv_lora_rank, rope_dim, front, near, topK, l_prime, block_size]
    input_param_path = Path(output, 'input_param.bin')
    dump_file(golden_input_params, input_param_path, "int32")
    gen_i_o_tensor(golden_input_params, s_slc, s2, dtype, output=output)
    return True
def main() -> bool:
    """Standalone debug entry point: generate golden data for every case.

    Output goes to <src_root>/build/output/bin/golden/<case_name>/.

    Returns:
        True only if every case succeeded.
    """
    case_name_list: List[str] = [
        "DynamicSlcTest.dynamic_p_slc_fp16",
        "DynamicSlcTest.dynamic_p_slc_bf16",
    ]
    ret: bool = True
    for cs in case_name_list:
        output: Path = Path(g_src_root, "build/output/bin/golden", cs).resolve()
        output.mkdir(parents=True, exist_ok=True)
        # Always run the case, but keep any earlier failure in `ret`.
        ret = kv_slc_func(case_name=cs, output=output) and ret
    return ret
if __name__ == "__main__":
    # sys.exit rather than the site-injected `exit`, which is meant for the
    # interactive shell and is absent under `python -S`.
    sys.exit(0 if main() else 1)