import os
import copy
import struct
import numpy as np
import ml_dtypes
# Short alias for the ml_dtypes bfloat16 scalar type referenced by the functions below.
bfloat16 = ml_dtypes.bfloat16
# Seed NumPy's global RNG so the np.random calls in this file produce the same data on every run.
np.random.seed(19)
def zero_pad(arr, target_shape, dtype):
    """
    Place ``arr`` in the leading corner of a zero-filled array of ``target_shape``.

    Works for arrays of any rank (the original implementation only handled 2-D input).
    Args:
        arr: Input NumPy array; its rank must equal the rank of ``target_shape``.
        target_shape: Shape of the returned array. Each extent must be at least as large
            as the matching extent of ``arr``, otherwise the assignment raises ValueError.
        dtype: Element type of the returned array; the values of ``arr`` are cast to it.
    Return:
        A new array of shape ``target_shape`` whose leading ``arr.shape`` region holds
        ``arr`` and whose remaining elements are zero.
    Raises:
        ValueError: If the rank of ``arr`` differs from the rank of ``target_shape``.
    """
    padded = np.zeros(target_shape, dtype=dtype)
    if arr.ndim != padded.ndim:
        raise ValueError(
            f"rank mismatch: arr has {arr.ndim} dims but target_shape has {padded.ndim}"
        )
    # One slice per axis selects the leading region that receives the input data.
    corner = tuple(slice(0, extent) for extent in arr.shape)
    padded[corner] = arr
    return padded
def saturation(value, min_val, max_val, target_type):
    """
    Clamp ``value`` to ``[min_val, max_val]`` and convert it to ``target_type``.

    Rounding to the nearest integer (``np.round``, ties to even) is applied only when
    ``target_type`` is an integer type. For floating-point targets the clamped value is
    cast directly, so the fractional part is kept instead of being rounded away.
    Args:
        value: Scalar or array to saturate.
        min_val: Lower clamp bound.
        max_val: Upper clamp bound.
        target_type: NumPy-compatible scalar type / dtype of the result.
    Return:
        The clamped value(s) converted to ``target_type``.
    """
    x_clamped = np.clip(value, min_val, max_val)
    # Only integer targets need an explicit round; astype() alone would truncate them.
    if np.issubdtype(np.dtype(target_type), np.integer):
        x_clamped = np.round(x_clamped)
    return x_clamped.astype(target_type)
def extract_quant_params(quant_gm):
    """
    Extract the parameters M1, offset, and sign from the quant_gm of type uint64.

    Bit layout read by this function (bit 0 is the least significant bit):
        bits 13..31: M1 in custom (1,8,10) format - 1 sign bit, 8 exponent bits
                     (bias 127), 10 mantissa bits
        bits 37..45: offset
        bit  46:     sign
    Args:
        quant_gm: An integer of type uint64 (any value accepted by ``int()``)
    Return:
        m1: A floating-point number decoded from the custom format (1,8,10)
        offset: A 9-bit integer, returned as the raw unsigned field (0..511)
        sign: A 1-bit boolean value (0 or 1)
    """
    quant_gm = int(quant_gm)
    m1_bits = (quant_gm >> 13) & 0x7FFFF
    offset = (quant_gm >> 37) & 0x1FF
    sign = (quant_gm >> 46) & 0x1
    sign_bit = (m1_bits >> 18) & 0x1
    exponent = (m1_bits >> 10) & 0xFF
    mantissa = m1_bits & 0x3FF
    exponent_bias = 127
    if exponent == 0:
        # Zero / subnormal encoding: there is no implicit leading one, so an all-zero
        # field decodes to exactly 0.0 instead of 2**-127.
        # NOTE(review): subnormals are decoded IEEE-style here; confirm whether the
        # target hardware flushes them to zero instead.
        magnitude = (mantissa / 1024) * 2.0 ** (1 - exponent_bias)
    else:
        magnitude = (1 + mantissa / 1024) * 2.0 ** (exponent - exponent_bias)
    m1 = -magnitude if sign_bit else magnitude
    return m1, offset, sign
def qf2b8_pre(data, quant_gm):
    """
    Quantize ``data`` to an 8-bit integer (float32 -> int8, int32 -> int8).

    The scale, offset and sign flag are decoded from ``quant_gm`` by
    ``extract_quant_params``. ``data`` is converted to float32, multiplied by the scale,
    saturated to [-256, 255] as int16, and the offset is added. The result is then
    saturated to int8 when the sign flag is set, otherwise to uint8.
    """
    scale, zero_point, is_signed = extract_quant_params(quant_gm)
    scaled = data.astype(np.float32) * scale
    shifted = saturation(scaled, -256, 255, np.int16) + zero_point
    if not is_signed:
        return saturation(shifted, 0, 255, np.uint8)
    return saturation(shifted, -128, 127, np.int8)
def qf2f16_pre(data, quant_gm):
    """
    Quantize ``data`` to float16 (float32 -> float16).

    Only the scale decoded from ``quant_gm`` is used; the offset and sign fields are
    ignored. ``data`` is converted to float32 and multiplied by the scale. Clamping to
    the finite float16 range and the final cast are delegated to ``saturation``.
    """
    scale = extract_quant_params(quant_gm)[0]
    half_limits = np.finfo(np.float16)
    scaled = data.astype(np.float32) * scale
    return saturation(scaled, half_limits.min, half_limits.max, np.float16)
def qf2bf16_pre(data, quant_gm):
    """
    Quantize ``data`` to bfloat16 (float32 -> bfloat16).

    Only the scale decoded from ``quant_gm`` is used; the offset and sign fields are
    ignored. ``data`` is converted to float32, multiplied by the scale, and passed to
    ``saturation`` with the finite bfloat16 range as clamp bounds.

    The previous bounds (0x0080, 0x7F80) were bfloat16 bit patterns used as numeric
    limits, i.e. the value was clamped to [128, 32640]. That forced every negative or
    small result to 128.
    """
    m1 = extract_quant_params(quant_gm)[0]
    # ml_dtypes.finfo understands bfloat16; min/max are the most negative / most
    # positive finite bfloat16 values.
    bf16_limits = ml_dtypes.finfo(bfloat16)
    return saturation(
        data.astype(np.float32) * m1,
        float(bf16_limits.min),
        float(bf16_limits.max),
        bfloat16,
    )
def get_vector_quant(golden, m, n, dst_type, quant_type):
    """
    Generate a per-column quant tensor, dump it to disk, and quantize ``golden`` with it.

    ``n`` random integer scales in [1, 5) are drawn as float32. The raw float32 bits of
    each scale are stored in a uint64 word, and bit 46 (0x400000000000) is additionally
    set when ``dst_type`` is np.int8. The words are cast to ``quant_type`` and written to
    "./quant_gm.bin".
    Args:
        golden: 2-D array indexed as ``golden[i, j]`` for ``i < m`` and ``j < n``.
        m: Number of rows to quantize.
        n: Number of columns to quantize (also the length of the quant tensor).
        dst_type: Output element type; selects the conversion routine
            (np.int8, np.float16, bfloat16, otherwise a plain multiply).
        quant_type: dtype the quant tensor is cast to before it is written and used.
    Return:
        An (m, n) array of ``dst_type`` where element [i, j] is ``golden[i, j]``
        converted with the quant word of column ``j``.
    """
    scales = np.random.randint(1, 5, n).astype(np.float32)
    packed = copy.deepcopy(scales).astype(np.uint64)
    wants_int8 = dst_type == np.int8
    for col in range(len(packed)):
        # Reinterpret the float32 scale as its 32-bit pattern.
        (raw_bits,) = struct.unpack('!I', struct.pack('!f', scales[col]))
        packed[col] = raw_bits
        if wants_int8:
            packed[col] = packed[col] | np.uint64(0x400000000000)
    quant_tensor = np.frombuffer(packed, np.uint64).astype(quant_type)
    quant_tensor.tofile("./quant_gm.bin")

    # Pick the element converter once; None means "multiply by the raw quant word".
    if wants_int8:
        convert = qf2b8_pre
    elif dst_type == np.float16:
        convert = qf2f16_pre
    elif dst_type == bfloat16:
        convert = qf2bf16_pre
    else:
        convert = None

    quant_golden = np.zeros((m, n), dtype=dst_type)
    for row, col in np.ndindex(m, n):
        if convert is None:
            cell = golden[row, col] * quant_tensor[col]
        else:
            cell = convert(golden[row, col], quant_tensor[col])
        quant_golden[row, col] = cell
    return quant_golden
def gen_golden_data(case_name, param):
    """
    Generate the input, destination and golden binaries for one test case.

    Files written to the current working directory:
        x1_gm.bin   zero-padded left matrix  (base_m x base_k, ``param.atype``)
        x2_gm.bin   zero-padded right matrix (base_k x base_n, ``param.btype``)
        quant_gm.bin  only when ``param.is_v_quant`` (written by get_vector_quant)
        dst.bin     zero-filled destination buffer
        golden.bin  expected result
    Args:
        case_name: Name of the test case. Not used by this function; kept so the
            call signature stays unchanged.
        param: Object exposing the attributes read below (atype, btype, ctype,
            dst_type, m, k, n, base_m, base_k, base_n, is_v_quant, is_s_quant,
            is_relu, index_rows, index_cols, dst_row, dst_col, and quant_type /
            scalar when the matching quant flag is set).
    """
    a_type, b_type, c_type, dst_type = param.atype, param.btype, param.ctype, param.dst_type
    m, k, n = param.m, param.k, param.n
    # The padded ("base") extents are never smaller than the logical extents.
    base_m = max(param.base_m, m)
    base_k = max(param.base_k, k)
    base_n = max(param.base_n, n)

    # int8 destinations use inputs in [-1, 1]; everything else uses [1, 4].
    if dst_type == np.int8:
        low, high = -1, 2
    else:
        low, high = 1, 5
    x1_gm = np.random.randint(low, high, [m, k]).astype(a_type)
    # The right matrix uses b_type (the original cast and padded it with a_type).
    x2_gm = np.random.randint(low, high, [k, n]).astype(b_type)
    x1_gm_padded = zero_pad(x1_gm, (base_m, base_k), a_type)
    x2_gm_padded = zero_pad(x2_gm, (base_k, base_n), b_type)
    x1_gm_padded.tofile("./x1_gm.bin")
    x2_gm_padded.tofile("./x2_gm.bin")

    golden = np.matmul(x1_gm_padded.astype(c_type), x2_gm_padded.astype(c_type)).astype(c_type)

    if param.is_v_quant:
        golden = get_vector_quant(golden, base_m, base_n, dst_type, param.quant_type)
    elif param.is_s_quant:
        golden = golden * param.scalar
        if dst_type == np.int8:
            golden = saturation(golden, -128, 127, np.int8)
        elif dst_type == np.uint8:
            golden = saturation(golden, 0, 255, np.uint8)
    if param.is_relu:
        golden = np.maximum(golden, 0)

    insert_at_origin = param.index_rows == 0 and param.index_cols == 0
    if insert_at_origin:
        dst_data = np.zeros((m, n), dtype=dst_type)
    else:
        dst_data = np.zeros((param.dst_row, param.dst_col), dtype=dst_type)
    dst_data.tofile("./dst.bin")

    if insert_at_origin:
        golden.astype(dst_type).tofile("./golden.bin")
        return

    # Copy the part of the result that fits into the destination window starting at
    # (index_rows, index_cols); anything that would fall outside dst is dropped.
    row_end = min(param.dst_row, param.index_rows + m)
    col_end = min(param.dst_col, param.index_cols + n)
    src_rows = min(m, param.dst_row - param.index_rows)
    src_cols = min(n, param.dst_col - param.index_cols)
    dst_data[param.index_rows:row_end, param.index_cols:col_end] = golden[0:src_rows, 0:src_cols]
    dst_data.tofile("./golden.bin")
class TMovParams:
    """
    Parameter bundle describing one golden-data test case.

    The accumulator type ``ctype`` is derived from ``atype``: np.int32 when ``atype`` is
    np.int8, np.float32 otherwise. All other constructor arguments are stored unchanged
    under attributes of the same name.

    ``quant_type`` and ``scalar`` are always stored, so reading them never raises
    AttributeError. Previously they only existed when ``is_v_quant`` / ``is_s_quant``
    was set.
    """

    def __init__(self, atype, btype, dst_type, m, k, n, base_m=0, base_k=0, base_n=0,
                 dst_format='ND', s_fractal_size=512, is_v_quant=False,
                 is_s_quant=False, is_relu=False, quant_type=None, scalar=1,
                 index_rows=0, index_cols=0, is_insert=False, dst_row=0, dst_col=0):
        # Element types of the two inputs, the accumulator and the destination.
        self.atype = atype
        self.btype = btype
        self.ctype = np.int32 if atype == np.int8 else np.float32
        self.dst_type = dst_type
        # Logical matrix extents and their padded ("base") counterparts.
        self.m = m
        self.k = k
        self.n = n
        self.base_m = base_m
        self.base_k = base_k
        self.base_n = base_n
        self.dst_format = dst_format
        self.s_fractal_size = s_fractal_size
        # Post-processing switches and their parameters.
        self.is_v_quant = is_v_quant
        self.is_s_quant = is_s_quant
        self.is_relu = is_relu
        self.quant_type = quant_type
        self.scalar = scalar
        # Placement of the result inside the destination buffer.
        self.index_rows = index_rows
        self.index_cols = index_cols
        self.is_insert = is_insert
        self.dst_row = dst_row
        self.dst_col = dst_col
if __name__ == "__main__":
    # Test case names; each one becomes a directory holding that case's binaries.
    case_name_list = [
        "TMOVTest.case_nz2nd_1",
        "TMOVTest.case_nz2nd_2",
        "TMOVTest.case_nz2nd_3",
        "TMOVTest.case_nz2nd_4",
        "TMOVTest.case_nz2nd_split_1",
        "TMOVTest.case_nz2nd_split_2",
        "TMOVTest.case_nz2nd_fb_quant_1",
        "TMOVTest.case_nz2nd_fb_quant_2",
        "TMOVTest.case_nz2nd_fb_quant_3",
        "TMOVTest.case_nz2nd_fb_quant_4",
        "TMOVTest.case_nz2nd_fb_quant_5",
        "TMOVTest.case_nz2nd_sc_quant_1",
        "TMOVTest.case_nz2nd_sc_quant_2",
        "TMOVTest.case_nz2nd_sc_quant_3",
        "TMOVTest.case_nz2nd_sc_quant_4",
    ]
    # Parameters for each case, in the same order as case_name_list.
    case_params_list = [
        TMovParams(np.float16, np.float16, np.float32, 60, 127, 120, 0, 0, 0, 'ND', 512, False, False, True, None, 1,
                   0, 8, True, 60, 128),
        TMovParams(np.float16, np.float16, np.float16, 110, 100, 80, 0, 0, 0, 'ND', 512, False, False, True, None, 1,
                   5, 0, True, 120, 96),
        TMovParams(np.float16, np.float16, np.float32, 6, 7, 8, 0, 0, 0, 'ND', 512, False, False, True, None, 1,
                   2, 0, True, 10, 16),
        TMovParams(np.float16, np.float16, bfloat16, 111, 47, 96, 0, 0, 0, 'ND', 512, False, False, True, None, 1,
                   3, 32, True, 150, 160),
        TMovParams(np.float16, np.float16, np.float32, 96, 32, 48),
        TMovParams(np.float16, np.float16, np.float32, 48, 32, 128),
        TMovParams(np.int8, np.int8, np.int8, 30, 48, 64, 0, 0, 0, 'ND', 512, True, False, False, np.uint64, 1,
                   0, 32, True, 40, 96),
        TMovParams(np.int8, np.int8, np.float16, 60, 128, 32, 0, 0, 0, 'ND', 512, True, False, False, np.uint64, 1,
                   5, 32, True, 70, 96),
        TMovParams(np.int8, np.int8, bfloat16, 128, 64, 96, 0, 0, 0, 'ND', 512, True, False, False, np.uint64, 1,
                   0, 0, True, 128, 96),
        TMovParams(np.float32, np.float32, np.int8, 60, 128, 64, 0, 0, 0, 'ND', 512, True, False, True, np.uint64, 1,
                   7, 64, True, 80, 256),
        TMovParams(np.float32, np.float32, np.float16, 31, 128, 128, 0, 0, 0, 'ND', 512, True, False, True,
                   np.uint64, 1, 0, 64, True, 40, 256),
        TMovParams(np.float32, np.float32, np.float16, 128, 48, 96, 0, 0, 0, 'ND', 512, False, True, True, None, 2,
                   0, 0, True, 128, 96),
        TMovParams(np.float32, np.float32, np.int8, 60, 128, 64, 0, 0, 0, 'ND', 512, False, True, True, None, 5,
                   0, 32, True, 128, 128),
        TMovParams(np.int8, np.int8, np.float16, 30, 48, 64, 0, 0, 0, 'ND', 512, False, True, False, None, 3,
                   5, 32, True, 40, 128),
        TMovParams(np.int8, np.int8, np.int8, 60, 128, 32, 0, 0, 0, 'ND', 512, False, True, False, None, 1,
                   3, 0, True, 64, 64),
    ]
    # zip() would silently drop unmatched entries, so fail loudly on a mismatch.
    if len(case_name_list) != len(case_params_list):
        raise ValueError(
            f"{len(case_name_list)} case names but {len(case_params_list)} parameter sets"
        )
    original_dir = os.getcwd()
    for case_name, case_params in zip(case_name_list, case_params_list):
        os.makedirs(case_name, exist_ok=True)
        # gen_golden_data writes to relative paths, so run it from inside the case
        # directory and always return to the starting directory afterwards.
        os.chdir(case_name)
        try:
            gen_golden_data(case_name, case_params)
        finally:
            os.chdir(original_dir)