import sys
import logging
from pathlib import Path
from typing import List
import numpy as np
from ml_dtypes import bfloat16
# Standalone-debug configuration: only when this file is executed directly,
# set up verbose logging and extend sys.path before golden_register is imported.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s: %(message)s',
    )
    # Source root is resolved relative to this file's directory.
    g_src_root: Path = (Path(__file__).parent / "../../../../../").resolve()
    logging.debug("SrcRoot: %s", g_src_root)
    # presumably the directory that holds golden_register — verify against the repo layout
    g_ctrl_path: Path = g_src_root / "cmake/scripts"
    if str(g_ctrl_path) not in sys.path:
        sys.path.append(str(g_ctrl_path))

# Needed in both modes. When this file is imported rather than run, sys.path is
# not touched here, so the importer must already be able to find golden_register.
from golden_register import GoldenRegister
# Lower-cased type names accepted by dump_file, mapped to NumPy scalar types.
# 'bf16' is handled inside dump_file because its dtype comes from ml_dtypes,
# not from NumPy.
_NUMPY_DTYPE_BY_NAME = {
    'fp16': np.float16,
    'fp32': np.float32,
    'fp64': np.float64,
    'int8': np.int8,
    'int16': np.int16,
    'int32': np.int32,
    'int64': np.int64,
    'uint8': np.uint8,
    'uint16': np.uint16,
    'uint32': np.uint32,
    'uint64': np.uint64,
    'complex64': np.complex64,
    'complex128': np.complex128,
    'bool': np.bool_,
}


def dump_file(data_pool, data_path, type_str):
    """Convert ``data_pool`` to the requested element type and write it to a file.

    The data is turned into a NumPy array, cast with ``astype`` and written
    with ``ndarray.tofile``.

    Args:
        data_pool: Anything ``np.array`` accepts (array, nested list, scalar).
        data_path: Destination file path, passed straight to ``tofile``.
        type_str: Case-insensitive type name. One of fp16, fp32, fp64, int8,
            int16, int32, int64, uint8, uint16, uint32, uint64, complex64,
            complex128, bool, bf16.

    Raises:
        ValueError: If ``type_str`` is not one of the supported names.
            Previously an unknown name fell through every branch and no file
            was written, without any indication to the caller.
    """
    type_key = type_str.lower()
    if type_key == 'bf16':
        target_dtype = bfloat16
    elif type_key in _NUMPY_DTYPE_BY_NAME:
        target_dtype = _NUMPY_DTYPE_BY_NAME[type_key]
    else:
        supported = sorted([*_NUMPY_DTYPE_BY_NAME, 'bf16'])
        raise ValueError(
            f"Unsupported type_str {type_str!r}; expected one of {supported}"
        )
    np.array(data_pool).astype(target_dtype).tofile(data_path)
def gen_uniform_data(data_shape, min_value, max_value, dtype):
    """Create an array of shape ``data_shape`` filled with random test data.

    Args:
        data_shape: Shape passed to NumPy as ``shape`` / ``size``.
        min_value: Lower bound handed to ``np.random.uniform``.
        max_value: Upper bound handed to ``np.random.uniform``.
        dtype: Element type of the returned array.

    Returns:
        * All zeros of ``dtype`` when both bounds are 0.
        * Otherwise, for ``np.bool_``, a random choice of True/False per
          element (the bounds are ignored).
        * Otherwise, uniform samples between the bounds cast to ``dtype``.
    """
    both_bounds_zero = min_value == 0 and max_value == 0
    if both_bounds_zero:
        return np.zeros(data_shape, dtype=dtype)

    if dtype == np.bool_:
        return np.random.choice([True, False], size=data_shape)

    samples = np.random.uniform(low=min_value, high=max_value, size=data_shape)
    return samples.astype(dtype)
# presumably registers this function as the golden generator for the listed
# case names — confirm against golden_register.
@GoldenRegister.reg_golden_func(
    case_names=[
        "DynamicBasicTest.TestInnerLoopOrder",
    ]
)
def gen_dynamic_basic_op_golden(case_name: str, output: Path) -> bool:
    """Generate input and expected-output files for a dynamic basic-op case.

    For "DynamicBasicTest.TestInnerLoopOrder" two float32 inputs are drawn with
    gen_uniform_data(-1, 1), the expected output is computed with NumPy, and
    all three arrays are written under ``output`` as ``input_a.bin``,
    ``input_b.bin`` and ``out.bin``.

    Args:
        case_name: Name of the test case to generate data for.
        output: Directory the ``.bin`` files are written into.

    Returns:
        True when the case was recognised and its files were written,
        False (after logging an error) for any other case name.
    """
    if case_name == "DynamicBasicTest.TestInnerLoopOrder":
        dtype = np.float32
        vec_len = 16
        loop_num = 4
        tile_num = 3
        shape_input_a = [loop_num, vec_len]
        shape_input_b = shape_input_a
        shape_out = [tile_num, vec_len]
        # input_a is drawn before input_b; keep this order so seeded runs
        # produce the same data.
        input_a = gen_uniform_data(shape_input_a, -1, 1, dtype)
        input_b = gen_uniform_data(shape_input_b, -1, 1, dtype)
        out = np.zeros(shape_out).astype(dtype)
        # NOTE(review): the loop nesting below was reconstructed from a source
        # whose indentation was lost — confirm that the "* 3.0" step belongs
        # after the inner loop rather than inside it.
        for i in range(tile_num):
            # Start from row i of input_b, doubled.
            tile_b = input_b[i : i + 1, 0 : vec_len] * 2.0
            # Accumulate every row of input_a onto the running tile, in order.
            for k in range(loop_num):
                tile_a = input_a[k : k + 1, :]
                tile_b = tile_a + tile_b
            # Scale the accumulated tile and store it as output row i.
            tile_b = tile_b * 3.0
            out[i : i + 1] = tile_b
        input_a.tofile(Path(output, 'input_a.bin'))
        input_b.tofile(Path(output, 'input_b.bin'))
        out.tofile(Path(output, 'out.bin'))
    else:
        logging.error("Can't get func to gen golden, Case(%s)", case_name)
        return False
    return True
def main() -> bool:
    """Entry point: generate golden data for every built-in case.

    For each case name an output directory
    ``<g_src_root>/build/output/bin/golden/<case name>`` is created (parents
    included) and gen_dynamic_basic_op_golden is called for it.

    Relies on the module-level ``g_src_root``, which is only assigned when this
    file is executed as a script.

    Returns:
        True only if every case was generated successfully. All cases are
        attempted even after one fails; previously the result of each case
        overwrote the previous one, so only the last case decided the outcome.
    """
    case_name_list: List[str] = [
        "DynamicBasicTest.TestInnerLoopOrder",
    ]
    ret: bool = True
    for cs in case_name_list:
        output: Path = Path(g_src_root, "build/output/bin/golden", cs).resolve()
        output.mkdir(parents=True, exist_ok=True)
        # Keep going after a failure, but never let a later success hide it.
        if not gen_dynamic_basic_op_golden(case_name=cs, output=output):
            ret = False
    return ret
if __name__ == "__main__":
    # Exit status 0 on success, 1 on failure. sys.exit is used instead of the
    # exit() helper, which is installed by the site module for interactive use
    # and is not available when site is disabled.
    sys.exit(0 if main() else 1)