import os
import numpy as np
import ml_dtypes
import en_dtypes
# Short aliases for the low-precision dtypes provided by ml_dtypes / en_dtypes.
bfloat16 = ml_dtypes.bfloat16
fp8_e4m3fn = ml_dtypes.float8_e4m3fn
fp8_e5m2 = ml_dtypes.float8_e5m2
float4_e2m1 = en_dtypes.float4_e2m1
float4_e1m2 = en_dtypes.float4_e1m2
# Seed NumPy's global RNG at import time so the np.random draws below are reproducible.
np.random.seed(19)
def check(x, n):
    """Normalise the digit string ``x`` to exactly ``n`` characters.

    Shorter strings are left-padded with ``'0'``. Longer strings are truncated
    from the left so the last ``n`` characters survive. (The previous version
    removed only a single leading character, so an input more than one
    character too long was still returned longer than ``n``.)

    Args:
        x: String of digits (e.g. the digits produced by ``bin``).
        n: Required length of the result.

    Returns:
        A string of length ``n`` (for ``n >= 0``).
    """
    missing = n - len(x)
    if missing > 0:
        return '0' * missing + x
    if missing < 0:
        # Slice by an explicit start index; ``x[-n:]`` would misbehave for n == 0.
        return x[len(x) - n:]
    return x
def cast(c, dtype):
    """Convert ``c`` to a NumPy array of the precision named by ``dtype``.

    Args:
        c: Array-like data.
        dtype: ``'fp16'`` for ``np.float16`` or ``'fp32'`` for ``np.float32``.

    Returns:
        A new NumPy array for a recognised tag; for any other tag ``c`` itself
        is handed back unchanged (no conversion, no copy).
    """
    for tag, np_type in (('fp16', np.float16), ('fp32', np.float32)):
        if dtype == tag:
            return np.array(c).astype(np_type)
    return c
def HF8(input):
    """Decode one 8-character bit string into the number it encodes.

    Presumably this is the HiFloat8 (HiF8) encoding, going by the names used
    in this file -- confirm against the format specification.

    Layout as read by this decoder (index 0 is the leftmost character):

    * ``input[0]``  -- sign (``'1'`` means negative).
    * A variable-length prefix starting at ``input[1]`` selects the branch:
      ``'11'``, ``'10'``, ``'01'``, ``'001'``, ``'0001'`` or ``'0000'``.
    * The characters after the prefix carry an exponent sign/magnitude and the
      fraction; the longer the exponent, the fewer fraction bits remain.

    Special cases:

    * prefix ``'0000'`` with the last three bits zero -> ``0`` (sign ``'0'``)
      or ``np.nan`` (sign ``'1'``);
    * prefix ``'0000'`` otherwise -> ``+/- 2 ** (low3 - 23)``;
    * prefix ``'11'`` with ``input[3:5] == '01'`` and the last three bits
      ``'111'`` -> ``+/- np.inf``.

    Args:
        input: String of exactly eight ``'0'``/``'1'`` characters.

    Returns:
        The decoded value (a Python number, ``np.nan`` or ``+/- np.inf``).

    A string whose length is not 8 prints a message and terminates the
    process through ``exit(-1)``.
    """
    if len(input) != 8:
        print("The input must be 8-bit. Please check the input.")
        exit(-1)

    bits = input
    sign_char, tail = bits[0], bits[5:]
    b5, b6, b7 = int(bits[5]), int(bits[6]), int(bits[7])

    # Split off the variable-length prefix and whatever follows it up to bit 4.
    if bits[1] == '1' or bits[2] == '1':
        prefix, exp_field = bits[1:3], bits[3:5]
    elif bits[3] == '1':
        prefix, exp_field = bits[1:4], bits[4]
    else:
        prefix, exp_field = bits[1:5], ''

    sign = -1 if sign_char == '1' else 1
    low3 = b5 * 4 + b6 * 2 + b7

    if prefix == '0000':
        # No implicit leading one here: the three low bits act as an exponent.
        if tail == '000':
            return np.nan if sign_char == '1' else 0
        magnitude = 2 ** (low3 - 23)
        return magnitude * sign if sign_char == '1' else magnitude

    if prefix == '0001':
        # Exponent is fixed at zero, so the scale factor is 2 ** 0.
        return (1 + low3 / 8) * 2 ** 0 * sign

    if prefix == '001':
        # Single exponent bit: it only chooses between 2 ** 1 and 2 ** -1.
        exponent = -1 if exp_field == '1' else 1
        return (1 + low3 / 8) * 2 ** exponent * sign

    if prefix == '01':
        direction = -1 if int(bits[3]) == 1 else 1
        return (1 + low3 / 8) * 2 ** (direction * (2 + int(bits[4]))) * sign

    if prefix == '10':
        # Bit 5 now belongs to the exponent, leaving a two-bit fraction.
        direction = -1 if int(bits[3]) == 1 else 1
        return (1 + (b6 * 2 + b7) / 4) * 2 ** (direction * (4 + int(bits[4]) * 2 + b5)) * sign

    if prefix == '11':
        # Bits 5 and 6 belong to the exponent, leaving a one-bit fraction.
        direction = -1 if int(bits[3]) == 1 else 1
        if exp_field == '01' and tail == '111':
            return sign * np.inf
        return (1 + b7 / 2) * 2 ** (direction * (8 + int(bits[4]) * 4 + b5 * 2 + b6)) * sign
def get_hif8_golden(x1_gm, x2_gm, start_m, start_k, start_n, dst_type):
    """Build the reference matmul result for operands stored as 8-bit codes.

    Every element of both operands is rendered as an 8-character bit string,
    decoded with ``HF8`` and collected as float32. The decoded matrices are
    then cropped at the start offsets and multiplied in ``dst_type``.

    Args:
        x1_gm: Left operand holding integer codes (shape ``(M, K)``).
        x2_gm: Right operand holding integer codes (shape ``(K, N)``).
        start_m: First row of ``x1_gm`` that takes part in the product.
        start_k: First column of ``x1_gm`` / first row of ``x2_gm`` used.
        start_n: First column of ``x2_gm`` that takes part in the product.
        dst_type: dtype used for the multiplication and for the result.

    Returns:
        ``x1[start_m:, start_k:] @ x2[start_k:, start_n:]`` as ``dst_type``.
    """
    def _decode_all(codes):
        # bin() yields e.g. '0b101'; keep the digits and pad them to 8 characters.
        decoded = [HF8(check(bin(code).split('b')[1], 8)) for code in codes]
        return np.array(decoded).astype(np.float32)

    flat_lhs = x1_gm.reshape(-1)
    flat_rhs = x2_gm.reshape(-1)
    lhs_values = _decode_all(flat_lhs)
    rhs_values = _decode_all(flat_rhs)

    lhs = lhs_values.reshape(x1_gm.shape)[start_m:, start_k:]
    rhs = rhs_values.reshape(x2_gm.shape)[start_k:, start_n:]
    return np.matmul(lhs.astype(dst_type), rhs.astype(dst_type)).astype(dst_type)
def data_processing_b4(valid_m, valid_k, valid_n, x1_gm, x2_gm, x1_slice, x2_slice, dst_type):
    """Apply per-block scales to 4-bit operands, then pack data and scales.

    Steps:

    1. Draw one uint8 scale code in ``[127, 130)`` per 32 elements along K for
       each operand; a code ``c`` stands for the factor ``2 ** (c - 127)``.
       (Presumably MX-style shared block scales -- confirm with the kernel
       specification.)
    2. Multiply the first ``valid_k`` columns of ``x1_slice`` / rows of
       ``x2_slice`` by their scale and compute the golden matmul in ``dst_type``.
    3. Rearrange both scale tables (the second one transposed first) into
       16-row by 2-column tiles and write them to ``./x1_mx_gm.bin`` and
       ``./x2_mx_gm.bin``.
    4. Pack each pair of neighbouring elements of ``x1_gm`` / ``x2_gm`` into one
       byte, halving the column count.

    Args:
        valid_m, valid_k, valid_n: Extents of the region that is multiplied.
        x1_gm, x2_gm: Full operands with one-byte elements; they get packed.
        x1_slice, x2_slice: Operand regions used for the golden product.
        dst_type: dtype of the scaled operands and of the golden result.

    Returns:
        ``(packed_x1, packed_x2, golden)``.
    """
    group = 32      # elements along K that share one scale
    c0_size = 2     # innermost tile width of the scale layout

    def _tile_scales(scales):
        # (R, C) -> (R/16, 16, C/c0, c0) -> move the C/c0 axis to the front -> 2-D.
        tiled = scales.reshape(
            (int(scales.shape[0] // 16), 16, int(scales.shape[1] // c0_size), c0_size)
        ).transpose(2, 0, 1, 3)
        return tiled.reshape(tiled.shape[0] * tiled.shape[1], tiled.shape[2] * tiled.shape[3])

    def _pack_pairs(matrix):
        # Even-indexed element -> low nibble, odd-indexed element -> high nibble.
        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        flat = matrix.flatten()
        even = flat[::2].view(np.uint8)
        odd = flat[1::2].view(np.uint8)
        return (((odd & 0x0F) << 4) | (even & 0x0F)).reshape(n_rows, n_cols // 2)

    # Keep this draw order: it determines the values taken from the global RNG.
    x1_mx_gm = np.random.randint(127, 130, [valid_m, valid_k // group]).astype(np.uint8)
    x2_mx_gm = np.random.randint(127, 130, [valid_k // group, valid_n]).astype(np.uint8)
    x1_mx = 2 ** (x1_mx_gm.astype(np.float32) - 127)
    x2_mx = 2 ** (x2_mx_gm.astype(np.float32) - 127)

    scaled_x1 = np.zeros(x1_slice.shape, dtype=dst_type)
    scaled_x2 = np.zeros(x2_slice.shape, dtype=dst_type)
    for k_idx in range(valid_k):
        blk = k_idx // group
        scaled_x1[:, k_idx] = x1_slice[:, k_idx] * x1_mx[:, blk]
        scaled_x2[k_idx, :] = x2_slice[k_idx, :] * x2_mx[blk, :]
    golden = np.matmul(scaled_x1.astype(dst_type), scaled_x2.astype(dst_type)).astype(dst_type)

    x2_mx_gm = x2_mx_gm.transpose()
    x1_mx_gm = _tile_scales(x1_mx_gm)
    x2_mx_gm = _tile_scales(x2_mx_gm)

    x1_gm = _pack_pairs(x1_gm)
    x2_gm = _pack_pairs(x2_gm)

    x1_mx_gm.tofile("./x1_mx_gm.bin")
    x2_mx_gm.tofile("./x2_mx_gm.bin")
    return x1_gm, x2_gm, golden
def create_padded_tensors(
        x1_gm, x2_gm, m, n, k, target_m, target_n, target_k, src_type=np.int8,
        rand_range_right=(1, 5),
        rand_range_down=(1, 5),
        rand_range_corner=(1, 5)):
    """Embed both matmul operands in larger buffers with controlled padding.

    ``x1_gm`` (m x k) is placed in the top-left corner of a
    ``target_m x target_k`` buffer and ``x2_gm`` (k x n) in the top-left corner
    of a ``target_k x target_n`` buffer. The padding is filled as follows:

    * x1, columns ``k:`` of the first ``m`` rows -- random (``rand_range_right``)
    * x1, rows ``m:`` of the first ``k`` columns -- zeros
    * x1, rows ``m:`` x columns ``k:``           -- random (``rand_range_corner``)
    * x2, rows ``k:`` of the first ``n`` columns -- random (``rand_range_down``)
    * x2, columns ``n:`` of the first ``k`` rows -- zeros
    * x2, rows ``k:`` x columns ``n:``           -- random (``rand_range_corner``)

    Random values come from ``np.random.randint(low, high)`` (``high``
    exclusive) and are drawn in the order right, x1-corner, down, x2-corner.

    Args:
        x1_gm: Left operand, shape ``(m, k)``.
        x2_gm: Right operand, shape ``(k, n)``.
        m, n, k: Logical matmul sizes.
        target_m, target_n, target_k: Padded sizes; each must be >= its
            logical counterpart.
        src_type: Element type of the returned buffers.
        rand_range_right: ``(low, high)`` for the x1 right-hand padding.
        rand_range_down: ``(low, high)`` for the x2 bottom padding.
        rand_range_corner: ``(low, high)`` for both corner paddings.

    Returns:
        ``(x1_gm_padded, x2_gm_padded)``.

    Raises:
        AssertionError: If a target size is smaller than the logical size.
    """
    assert target_m >= m, f"target_m ({target_m}) must be >= m ({m})"
    assert target_n >= n, f"target_n ({target_n}) must be >= n ({n})"
    assert target_k >= k, f"target_k ({target_k}) must be >= k ({k})"

    def _random_block(bounds, shape):
        # Draw as int32 first, then convert, so any src_type can be targeted.
        return np.random.randint(bounds[0], bounds[1], size=shape, dtype=np.int32).astype(src_type)

    # The buffers start out all-zero, so the zero-padded regions need no extra write.
    x1_gm_padded = np.zeros((target_m, target_k), dtype=np.int32).astype(src_type)
    x1_gm_padded[:m, :k] = x1_gm
    x1_gm_padded[:m, k:target_k] = _random_block(rand_range_right, (m, target_k - k))
    x1_gm_padded[m:target_m, k:target_k] = _random_block(
        rand_range_corner, (target_m - m, target_k - k))

    x2_gm_padded = np.zeros((target_k, target_n), dtype=np.int32).astype(src_type)
    x2_gm_padded[:k, :n] = x2_gm
    x2_gm_padded[k:target_k, :n] = _random_block(rand_range_down, (target_k - k, n))
    x2_gm_padded[k:target_k, n:target_n] = _random_block(
        rand_range_corner, (target_k - k, target_n - n))

    return x1_gm_padded, x2_gm_padded
def gen_golden_data(case_name, param):
    """Generate the inputs and golden output for one case and write them out.

    Writes ``./x1_gm.bin``, ``./x2_gm.bin`` and ``./golden.bin`` into the
    current working directory (the 4-bit path additionally writes the scale
    files produced by ``data_processing_b4``).

    Flow:

    1. Draw random operands: ``uniform(1, 6)`` for the 4-bit float types,
       ``randint(1, 5)`` otherwise, both converted to ``param.atype``.
    2. Compute the golden product of the regions that start at
       ``(start_m, start_k)`` / ``(start_k, start_n)`` in ``param.ctype``.
    3. If any ``target_*`` is positive, pad the operands with
       ``create_padded_tensors`` (the golden from step 2 is kept).
    4. For ``np.uint8`` inputs, replace the golden with the one computed by
       ``get_hif8_golden`` from the decoded codes.
    5. Transpose ``x1`` when ``is_atrans`` is set, and ``x2`` when
       ``is_btrans`` is *not* set.
    6. For the 4-bit float types, apply scales and pack via
       ``data_processing_b4``, which also supplies the final golden.

    Args:
        case_name: Name of the case; not used inside this function.
        param: Object exposing ``atype``, ``ctype``, ``m``, ``k``, ``n``,
            ``start_m``, ``start_k``, ``start_n``, ``is_atrans``,
            ``is_btrans``, ``target_m``, ``target_k`` and ``target_n``.
    """
    src_type, dst_type = param.atype, param.ctype
    m, k, n = param.m, param.k, param.n
    start_m, start_k, start_n = param.start_m, param.start_k, param.start_n
    is_atrans, is_btrans = param.is_atrans, param.is_btrans
    target_m, target_k, target_n = param.target_m, param.target_k, param.target_n

    is_fp4 = src_type == float4_e2m1 or src_type == float4_e1m2

    if is_fp4:
        x1_gm = np.random.uniform(1, 6, [m, k]).astype(src_type)
        x2_gm = np.random.uniform(1, 6, [k, n]).astype(src_type)
    else:
        x1_gm = np.random.randint(1, 5, [m, k]).astype(src_type)
        x2_gm = np.random.randint(1, 5, [k, n]).astype(src_type)

    # These slices are taken before any padding or transposition happens.
    x1_slice = x1_gm[start_m:, start_k:]
    x2_slice = x2_gm[start_k:, start_n:]
    golden = np.matmul(x1_slice.astype(dst_type), x2_slice.astype(dst_type)).astype(dst_type)

    if target_m > 0 or target_n > 0 or target_k > 0:
        # A non-positive target means "no padding along this axis".
        if not target_m > 0:
            target_m = m
        if not target_n > 0:
            target_n = n
        if not target_k > 0:
            target_k = k
        x1_gm, x2_gm = create_padded_tensors(
            x1_gm, x2_gm, m, n, k, target_m, target_n, target_k, src_type,
            rand_range_right=(1, 5), rand_range_down=(1, 5), rand_range_corner=(1, 5))

    if param.atype == np.uint8:
        golden = get_hif8_golden(x1_gm, x2_gm, start_m, start_k, start_n, dst_type)

    if is_atrans:
        x1_gm = x1_gm.transpose()
    if not is_btrans:
        x2_gm = x2_gm.transpose()

    if is_fp4:
        x1_gm, x2_gm, golden = data_processing_b4(
            m - start_m, k - start_k, n - start_n,
            x1_gm, x2_gm, x1_slice, x2_slice, dst_type)

    for array, path in ((x1_gm, "./x1_gm.bin"), (x2_gm, "./x2_gm.bin"), (golden, "./golden.bin")):
        array.tofile(path)
class textractParams:
    """Plain parameter record describing one generated case.

    Attributes (all stored exactly as passed in):
        atype, btype, ctype: Element types of the two inputs and of the output.
        m, k, n: Matmul sizes.
        start_m, start_k, start_n: Offsets at which the multiplied region starts.
        is_atrans, is_btrans: Transposition flags for the two inputs (default 0).
        target_m, target_k, target_n: Padded sizes; the default 0 requests no
            padding along that axis.
    """

    def __init__(self, atype, btype, ctype, m, k, n, start_m, start_k, start_n,
                 is_atrans=0, is_btrans=0, target_m=0, target_k=0, target_n=0):
        self.atype, self.btype, self.ctype = atype, btype, ctype
        self.m, self.k, self.n = m, k, n
        self.start_m, self.start_k, self.start_n = start_m, start_k, start_n
        self.is_atrans, self.is_btrans = is_atrans, is_btrans
        self.target_m, self.target_k, self.target_n = target_m, target_k, target_n
if __name__ == "__main__":
    # Each case name is kept next to its parameters so the two cannot drift
    # apart (they used to live in two parallel lists matched only by index).
    # Positional arguments of textractParams:
    #   atype, btype, ctype, m, k, n, start_m, start_k, start_n,
    #   is_atrans, is_btrans[, target_m, target_k, target_n]
    # The order of the entries matters: all cases draw from the same seeded RNG.
    cases = [
        ("TEXTRACTTest.case1", textractParams(np.float16, np.float16, np.float32, 32, 96, 64, 0, 0, 0, 0, 0)),
        ("TEXTRACTTest.case2", textractParams(np.float32, np.float32, np.float32, 128, 48, 64, 0, 0, 0, 0, 0)),
        ("TEXTRACTTest.case3", textractParams(np.int8, np.int8, np.int32, 128, 128, 64, 0, 0, 0, 0, 0)),
        ("TEXTRACTTest.case4", textractParams(np.float16, np.float16, np.float32, 64, 96, 64, 32, 16, 16, 0, 0)),
        ("TEXTRACTTest.case5", textractParams(np.float32, np.float32, np.float32, 64, 128, 64, 32, 32, 16, 0, 0)),
        ("TEXTRACTTest.case6", textractParams(np.int8, np.int8, np.int32, 128, 128, 64, 32, 64, 32, 0, 0)),
        ("TEXTRACTTest.case7", textractParams(np.float16, np.float16, np.float32, 64, 128, 64, 0, 64, 0, 1, 1)),
        ("TEXTRACTTest.case8", textractParams(np.float32, np.float32, np.float32, 64, 64, 128, 0, 0, 32, 1, 1)),
        ("TEXTRACTTest.case9", textractParams(np.int8, np.int8, np.int32, 128, 64, 128, 32, 0, 0, 1, 1)),
        ("TEXTRACTTest.case10", textractParams(bfloat16, bfloat16, np.float32, 64, 128, 64, 16, 0, 0, 1, 0)),
        ("TEXTRACTTest.case11", textractParams(fp8_e4m3fn, fp8_e4m3fn, np.float32, 64, 128, 64, 0, 32, 0, 1, 0)),
        ("TEXTRACTTest.case12", textractParams(fp8_e5m2, fp8_e5m2, np.float32, 64, 128, 64, 0, 0, 32, 0, 1)),
        ("TEXTRACTTest.case13", textractParams(np.uint8, np.uint8, np.float32, 64, 128, 64, 0, 32, 0, 0, 1)),
        ("TEXTRACTTest.case14", textractParams(np.int8, np.int8, np.int32, 64, 96, 32, 32, 0, 0, 1, 0)),
        ("TEXTRACTTest.case15", textractParams(np.float16, np.float16, np.float32, 64, 48, 96, 16, 16, 0, 1, 0)),
        ("TEXTRACTTest.case16", textractParams(np.float32, np.float32, np.float32, 32, 96, 48, 0, 32, 16, 0, 0)),
        ("TEXTRACTTest.case17", textractParams(float4_e2m1, float4_e2m1, np.float32, 256, 128, 256, 128, 64, 128, 0, 0)),
        ("TEXTRACTTest.case18", textractParams(float4_e1m2, float4_e1m2, np.float32, 256, 128, 256, 128, 64, 128, 0, 0)),
        ("TEXTRACTTest.case19", textractParams(float4_e2m1, float4_e2m1, np.float32, 256, 128, 256, 128, 64, 128, 1, 1)),
        ("TEXTRACTTest.case20", textractParams(float4_e1m2, float4_e1m2, np.float32, 256, 128, 256, 128, 64, 128, 1, 1)),
        ("TMOVTest.case1", textractParams(np.float16, np.float16, np.float32, 32, 96, 64, 0, 0, 0, 0, 0)),
        ("TMOVTest.case2", textractParams(np.float32, np.float32, np.float32, 128, 48, 64, 0, 0, 0, 0, 0)),
        ("TMOVTest.case3", textractParams(np.int8, np.int8, np.int32, 128, 128, 64, 0, 0, 0, 0, 0)),
        ("TMOVTest.case4", textractParams(bfloat16, bfloat16, np.float32, 64, 128, 64, 0, 0, 0, 1, 1)),
        ("TMOVTest.case5", textractParams(fp8_e4m3fn, fp8_e4m3fn, np.float32, 64, 96, 64, 0, 0, 0, 1, 0)),
        ("TMOVTest.case6", textractParams(fp8_e5m2, fp8_e5m2, np.float32, 64, 128, 64, 0, 0, 0, 0, 1)),
        ("TMOVTest.case7", textractParams(np.uint8, np.uint8, np.float32, 128, 128, 64, 0, 0, 0, 1, 0)),
        ("TMOVTest.case8", textractParams(np.int8, np.int8, np.int32, 64, 96, 64, 0, 0, 0, 1, 1)),
        ("TMOVTest.case9", textractParams(np.float16, np.float16, np.float32, 64, 128, 64, 0, 0, 0, 1, 0)),
        ("TMOVTest.case10", textractParams(np.float32, np.float32, np.float32, 64, 128, 64, 0, 0, 0, 0, 1)),
        ("TMOVTest.case11", textractParams(np.int8, np.int8, np.int32, 65, 40, 66, 0, 0, 0, 1, 1, 96, 64, 96)),
        ("TMOVTest.case12", textractParams(np.float16, np.float16, np.float32, 65, 40, 66, 0, 0, 0, 1, 1, 80, 48, 80)),
        ("TMOVTest.case13", textractParams(np.float32, np.float32, np.float32, 65, 40, 66, 0, 0, 0, 1, 1, 80, 48, 80)),
        ("TMOVTest.case14", textractParams(float4_e2m1, float4_e2m1, np.float32, 128, 64, 128, 0, 0, 0, 0, 0)),
        ("TMOVTest.case15", textractParams(float4_e1m2, float4_e1m2, np.float32, 128, 64, 128, 0, 0, 0, 0, 0)),
        ("TMOVTest.case16", textractParams(float4_e2m1, float4_e2m1, np.float32, 128, 64, 128, 0, 0, 0, 1, 1)),
        ("TMOVTest.case17", textractParams(float4_e1m2, float4_e1m2, np.float32, 128, 64, 128, 0, 0, 0, 1, 1)),
    ]
    # The original list names are kept in case anything further down still reads them.
    case_name_list = [name for name, _ in cases]
    case_params_list = [params for _, params in cases]

    launch_dir = os.getcwd()
    for case_name, case_params in cases:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(case_name, exist_ok=True)
        os.chdir(case_name)
        try:
            # gen_golden_data writes its .bin files relative to the current directory.
            gen_golden_data(case_name, case_params)
        finally:
            # Always return to the launch directory, even when a case fails.
            os.chdir(launch_dir)