import os
import sys
import logging
import torch
import numpy as np
# Module-level defaults; SparseMatmulGenData.__init__ copies the first four
# onto every instance.

# When True, gen_golden_data transposes A before writing x1_gm.bin.
IS_TRANS_A = False
# Stored on the instance as is_trans_b; the code visible in this file does not
# read it again (B is always generated with shape (n, k)).
IS_TRANS_B = True
# When True, a random int32 bias is generated, added to the golden result and
# written to bias_gm.bin.
IS_BIAS = False
# Data type tag; check_params only accepts "int8_int32".
DATA_TYPE_STR = "int8_int32"
# When True, gen_golden_data also dumps every tensor as a .txt file next to
# the .bin file.
IS_OUTPUT_TXT = False
class SparseMatmulGenData:
    """Generate input and golden files for a 2:4-sparse int8 matmul test.

    B is created as an (n, k) int8 matrix in which every group of 4
    consecutive elements of a row holds 2 non-zero values.  It is then
    compressed into a dense (n, k/2) matrix plus a 2-bit-per-element index
    matrix, and the golden result C (m, n) is computed in int32.
    """

    def __init__(self, m, n, k, b):
        """Store the problem size and copy the module-level configuration.

        Args:
            m: number of rows of A / C.
            n: number of rows of B (columns of C).
            k: reduction dimension (columns of A and of the sparse B).
            b: batch count; check_params only accepts 1.
        """
        self.m = m
        self.n = n
        self.k = k
        self.b = b
        self.is_trans_a = IS_TRANS_A
        self.is_trans_b = IS_TRANS_B
        self.is_bias = IS_BIAS
        self.data_type_str = DATA_TYPE_STR

    @staticmethod
    def _padded_index_shape(n, k_bytes):
        """Return (rows, cols) of the index matrix padded to (16, 8) tiles."""
        return -(-n // 16) * 16, -(-k_bytes // 8) * 8

    @staticmethod
    def generate_sparse_matrix_b(shape, dtype=np.int8):
        """Generate a sparse matrix B of the given shape.

        Every block of 4 consecutive elements in a row gets exactly 2
        non-zero values (drawn from 1..9) at random positions, so each block
        contains at least 2 zeros.

        Args:
            shape: (n, k) tuple; k must be a multiple of 4.
            dtype: integer dtype of the generated matrix.

        Returns:
            An (n, k) array of ``dtype``.

        Raises:
            ValueError: if k is not a multiple of 4.
        """
        n, k = shape
        if k % 4 != 0:
            raise ValueError("k (%d) must be a multiple of 4" % k)
        b = np.zeros((n, k), dtype=dtype)
        for row in range(n):
            for i in range(0, k, 4):
                block = np.zeros(4, dtype=dtype)
                non_zero_positions = np.random.choice(4, 2, replace=False)
                block[non_zero_positions[0]] = np.random.randint(1, 10, dtype=dtype)
                block[non_zero_positions[1]] = np.random.randint(1, 10, dtype=dtype)
                b[row, i:i + 4] = block
        return b

    @staticmethod
    def densify_and_generate_index(b):
        """Compress sparse matrix B and build its index matrices.

        For each block of 4 elements two values are kept ("slot 0" and
        "slot 1").  The 2-bit index of slot 0 is the in-block position of its
        value; the 2-bit index of slot 1 is the in-block position minus one.
        Four 2-bit indices are packed into one byte, lowest bits first.

        If a block holds more than 2 non-zero values only the first two are
        kept.

        Args:
            b: (n, k) sparse matrix; k must be a multiple of 8.

        Returns:
            A tuple ``(dense_b, index_matrix, index_mask_matrix)``:
            dense_b is (n, k/2) with the dtype of ``b``; index_matrix is
            (n, k/8) uint8 with the packed indices; index_mask_matrix is
            (n, k/2) uint32 holding, for every dense element, the column of
            the original row it corresponds to.

        Raises:
            ValueError: if k is not a multiple of 8.
        """
        n, k = b.shape
        if k % 8 != 0:
            raise ValueError("k (%d) must be a multiple of 8" % k)
        dense_b = np.zeros((n, k // 2), dtype=b.dtype)
        index_matrix = np.zeros((n, k // 8), dtype=np.uint8)
        index_mask_matrix = np.zeros((n, k // 2), dtype=np.uint32)
        for row in range(n):
            dense_row = []
            index_row = []
            index_mask_row = []
            for i in range(0, k, 4):
                block = b[row, i:i + 4]
                nonzero_positions = [j for j in range(4) if block[j] != 0]
                if not nonzero_positions:
                    index_1, index_2 = 0, 0
                    columns = [i, i]
                    values = [0, 0]
                elif len(nonzero_positions) == 1:
                    pos = nonzero_positions[0]
                    if pos < 3:
                        index_1, index_2 = pos, 0
                        columns = [pos + i, i]
                        values = [block[pos], 0]
                    else:
                        # Position 3 cannot be addressed by slot 0 (its index
                        # only covers 0..2), so the value must live in slot 1
                        # (index 2 -> position 3) and slot 0 carries the zero
                        # filler.  Keeping the value in slot 0 would
                        # contradict the emitted indices.
                        index_1, index_2 = 0, 2
                        columns = [i, pos + i]
                        values = [0, block[pos]]
                else:
                    first, second = nonzero_positions[0], nonzero_positions[1]
                    index_1, index_2 = first, second - 1
                    columns = [first + i, second + i]
                    values = [block[first], block[second]]
                dense_row.extend(values)
                index_row.extend([index_1, index_2])
                index_mask_row.extend(columns)
            # Pack four 2-bit indices per byte, first index in the low bits.
            index_bytes = [
                sum(index << (2 * bit_pos)
                    for bit_pos, index in enumerate(index_row[j:j + 4]))
                for j in range(0, len(index_row), 4)
            ]
            dense_b[row, :] = dense_row
            index_matrix[row, :] = index_bytes
            index_mask_matrix[row, :] = index_mask_row
        return dense_b, index_matrix, index_mask_matrix

    @staticmethod
    def gen_sparse_golden(a, dense_b, index_mask_matrix):
        """Compute the int32 golden result from A and the compressed B.

        Args:
            a: (m, k) matrix.
            dense_b: (n, k/2) compressed B.
            index_mask_matrix: (n, k/2) column of A matching each dense value.

        Returns:
            (m, n) int32 matrix where column r is the dot product of the
            selected columns of A with row r of ``dense_b``.
        """
        result_type = np.int32
        m = a.shape[0]
        n = dense_b.shape[0]
        c = np.zeros((m, n), dtype=result_type)
        for r in range(n):
            a_selected = a[:, index_mask_matrix[r]].astype(result_type)
            c[:, r] = np.dot(a_selected, dense_b[r].astype(result_type))
        return c

    @staticmethod
    def index_nd_to_nz(index_matrix):
        """Zero-pad the index matrix to (16, 8) tiles and regroup it by tile.

        Args:
            index_matrix: (n, k_bytes) uint8 matrix.

        Returns:
            A 4-D uint8 view of shape
            (ceil(k_bytes/8), ceil(n/16), 16, 8); presumably the "NZ" layout
            the name refers to.
        """
        n, k = index_matrix.shape
        ceil_n, ceil_k = SparseMatmulGenData._padded_index_shape(n, k)
        padded = np.zeros((ceil_n, ceil_k), dtype=np.uint8)
        padded[:n, :k] = index_matrix
        tiled = padded.reshape((ceil_n // 16, 16, ceil_k // 8, 8))
        return tiled.transpose(2, 0, 1, 3)

    def check_params(self):
        """Validate the configuration.

        Returns:
            0 when the parameters are supported, -1 otherwise (the reason is
            logged at ERROR level so it is visible with default logging).
        """
        if self.data_type_str != "int8_int32":
            logging.error("[ERROR] can't support data type %s", self.data_type_str)
            return -1
        if self.k % 8 != 0:
            logging.error("[ERROR] sparse k %d must be multiple of 8", self.k)
            return -1
        if self.b != 1:
            logging.error("[ERROR] sparse batch %d must be 1", self.b)
            return -1
        return 0

    def gen_golden_data(self, work_dir):
        """Generate random inputs and the golden output under ``work_dir``.

        Writes input/x1_gm.bin, input/x2_gm.bin, input/index_gm.bin,
        output/golden.bin (plus input/bias_gm.bin when bias is enabled) and,
        when IS_OUTPUT_TXT is set, matching .txt dumps.  The input/ and
        output/ directories must already exist.

        Returns:
            0 on success, -1 when check_params rejects the configuration.
        """
        if self.check_params() != 0:
            return -1
        a_gm = np.random.randint(-10, 10, [self.m, self.k], dtype=np.int8)
        b_sparse = self.generate_sparse_matrix_b((self.n, self.k)).astype(np.int8)
        b_gm, index_matrix, index_mask_matrix = self.densify_and_generate_index(b_sparse)
        index_gm = self.index_nd_to_nz(index_matrix)
        bias_gm = None
        if self.is_bias:
            bias_gm = np.random.randint(-10, 10, (self.b, 1, self.n), dtype=np.int32)
        # The golden result is computed from the untransposed A.
        c_gm = self.gen_sparse_golden(a_gm, b_gm, index_mask_matrix)
        if bias_gm is not None:
            c_gm = c_gm + bias_gm
        if self.is_trans_a:
            a_gm = a_gm.T

        input_dir = os.path.join(work_dir, "input")
        output_dir = os.path.join(work_dir, "output")
        a_gm.tofile(os.path.join(input_dir, "x1_gm.bin"))
        b_gm.tofile(os.path.join(input_dir, "x2_gm.bin"))
        index_gm.tofile(os.path.join(input_dir, "index_gm.bin"))
        c_gm.tofile(os.path.join(output_dir, "golden.bin"))
        if bias_gm is not None:
            bias_gm.tofile(os.path.join(input_dir, "bias_gm.bin"))
        if IS_OUTPUT_TXT:
            np.savetxt(os.path.join(input_dir, "x1_gm.txt"), a_gm, fmt='%d', newline='\n')
            np.savetxt(os.path.join(input_dir, "x2_gm.txt"), b_gm, fmt='%d', newline='\n')
            np.savetxt(os.path.join(input_dir, "index_gm.txt"), index_gm.flatten(),
                       fmt='%d', newline='\n')
            # Adding the (b, 1, n) bias makes c_gm 3-D, which savetxt rejects;
            # collapse the leading dimensions for the text dump only.
            np.savetxt(os.path.join(output_dir, "golden.txt"),
                       c_gm.reshape(-1, c_gm.shape[-1]), fmt='%d', newline='\n')
            if bias_gm is not None:
                np.savetxt(os.path.join(input_dir, "bias_gm.txt"), bias_gm.flatten(),
                           fmt='%d', newline='\n')
        return 0

    def gen_fake_golden_data(self, work_dir):
        """Create zero-filled input files of the right size under ``work_dir``.

        No random data or golden output is produced.  File sizes match what
        gen_golden_data writes, including the tile padding of the index.

        Returns:
            0 on success, -1 when check_params rejects the configuration.
        """
        if self.check_params() != 0:
            return -1
        data_type_bytes_ab = 1
        data_type_bytes_c = 4
        index_rows, index_cols = self._padded_index_shape(self.n, self.k // 8)
        file_sizes = [
            ("x1_gm.bin", self.m * self.k * data_type_bytes_ab),
            ("x2_gm.bin", self.k * self.n * data_type_bytes_ab // 2),
            # The real index file is padded to (16, 8) tiles by index_nd_to_nz.
            ("index_gm.bin", index_rows * index_cols),
        ]
        if self.is_bias:
            file_sizes.append(("bias_gm.bin", 1 * self.n * data_type_bytes_c))
        for file_name, file_byte in file_sizes:
            with open(os.path.join(work_dir, "input", file_name), 'wb') as file:
                file.truncate(file_byte)
        return 0
if __name__ == "__main__":
    # Usage: <script> M N K
    # Generates the golden data in input/ and output/ next to this script.
    if len(sys.argv) < 4:
        sys.stderr.write("usage: %s M N K\n" % sys.argv[0])
        sys.exit(1)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    shape_m = int(sys.argv[1])
    shape_n = int(sys.argv[2])
    shape_k = int(sys.argv[3])
    shape_b = 1
    matmul_gen_data = SparseMatmulGenData(shape_m, shape_n, shape_k, shape_b)
    data_dir = os.path.join(current_dir, "input")
    out_dir = os.path.join(current_dir, "output")
    os.makedirs(data_dir, exist_ok=True)
    os.makedirs(out_dir, exist_ok=True)
    # Propagate failures (e.g. rejected parameters) through the exit status
    # instead of silently exiting with 0.
    if matmul_gen_data.gen_golden_data(current_dir) != 0:
        sys.exit(1)