import os
import sys
import torch
import numpy as np
import pandas as pd
from enum import Enum
# Module-level constants. None of these three is referenced by the code
# visible in this file; they are kept for any external users.
CONST_16 = 16
RECORD_COUNT = 10
DATA_RANGE = (-1.0, 1.0)

# Directory the script was launched from, also exported to the environment.
WORKSPACE = os.getcwd()

# Environment configuration applied at import time.
# NOTE(review): the ASCEND_* values presumably quiet the Ascend runtime
# logging -- confirm the exact meaning against the toolkit documentation.
os.environ.update(
    {
        "WORKSPACE": WORKSPACE,
        "ASCEND_GLOBAL_LOG_LEVEL": "3",
        "ASCEND_SLOG_PRINT_TO_STDOUT": "0",
    }
)
class CubeFormat(Enum):
    """Enumeration of layout tags used for the A, B and C operands.

    The member names look like matrix data-layout identifiers; their exact
    memory meaning is not defined in this file (TODO confirm against the
    operator documentation).
    """

    ND = 0
    NZ = 1
    ZN = 2
    ZZ = 3
    NN = 4
    VECTOR = 5

    def __repr__(self) -> str:
        """Return the bare member name, e.g. ``ND``."""
        # Enum members expose their name via ``.name``; they have no
        # ``__name__`` attribute, so the previous implementation raised
        # AttributeError whenever repr() was called on a member.
        return self.name
class OpParam:
    """Mutable bag of settings describing one generated test case.

    Attributes are plain instance fields that callers assign directly:
    the four shape values ``b``, ``m``, ``k``, ``n``; the transpose flags
    ``transA``/``transB``; the ``enBias``/``enScale``/``enResidual``
    switches; and one ``CubeFormat`` layout tag per operand.
    """

    def __init__(self) -> None:
        """Initialise every field to its zero/False/ND default."""
        # Shape values (reported as "Shape" by __str__).
        self.b = self.m = self.k = self.n = 0
        # Transpose flags for the A and B operands.
        self.transA = self.transB = False
        # Optional stages (reported as "(De)Quant" by __str__).
        self.enBias = self.enScale = self.enResidual = False
        # Layout tags for the A, B and C operands.
        self.layoutA = self.layoutB = self.layoutC = CubeFormat.ND

    def __str__(self) -> str:
        """Return a four-line, human-readable summary of all fields."""
        return (
            f"Shape: ({self.b}, {self.m}, {self.k}, {self.n}) \n"
            f"Transpose: A {self.transA}, B {self.transB} \n"
            f"(De)Quant: Bias {self.enBias}, Scale {self.enScale}, Residual {self.enResidual} \n"
            f"Layout: layoutA {self.layoutA}, layoutB {self.layoutB}, layoutC {self.layoutC}"
        )
def gen_rand(msize, nsize, low, high):
    """Return a float32 tensor of shape (msize, nsize) drawn between low and high.

    Samples come from ``torch.rand`` (uniform on [0, 1)) and are then
    stretched by ``high - low`` and shifted by ``low``.
    """
    span = high - low
    unit_samples = torch.rand((msize, nsize), dtype=torch.float32)
    return low + span * unit_samples
def gen_data_int8(row, col):
    """Return a (row, col) int8 array of random integers in [-8, 8)."""
    return np.random.randint(-8, 8, size=(row, col), dtype=np.int8)
def _pack_int4_pairs(values):
    """Pack consecutive int8 pairs of a 2-D array into single bytes.

    The array is read in row-major order; within every pair the first value
    supplies the low nibble (bits 0-3) and the second the high nibble
    (bits 4-7). The total element count must be even.
    """
    # Work on an unsigned view so the left shift cannot overflow a signed
    # byte; the resulting bit patterns are identical to signed arithmetic.
    pairs = values.reshape(-1, 2).view(np.uint8)
    low_nibble = pairs[:, 0] & 0x0F
    high_nibble = (pairs[:, 1] & 0x0F) << 4
    return (low_nibble | high_nibble).view(np.int8)


def gen_data_int4(row: int, col: int, trans: bool):
    """Generate a random int4-range matrix and its nibble-packed form.

    Args:
        row: Number of rows of the generated matrix.
        col: Number of columns of the generated matrix.
        trans: When true, the packed data follows the transposed
            (col, row) layout instead of the (row, col) layout.

    Returns:
        A pair ``(matrix, packed)``. ``matrix`` is the (row, col) int8 array
        with values in [-8, 8). ``packed`` is a 1-D int8 array holding two
        values per byte, taken in row-major order from the chosen layout.
        If that layout has an odd number of columns, one zero column is
        appended first so every row packs into whole bytes.
    """
    matrix = np.random.randint(-8, 8, size=(row, col), dtype=np.int8)
    layout = matrix.T if trans else matrix

    # Pad the innermost dimension to an even length before pairing values.
    if layout.shape[1] % 2 != 0:
        zero_column = np.zeros((layout.shape[0], 1), dtype=np.int8)
        layout = np.hstack((layout, zero_column))

    return matrix, _pack_int4_pairs(layout)
def gen_testcase(path: str, param: OpParam) -> None:
    """Generate one random matmul test case and write it into ``path``.

    Four raw binary files are produced:

    * ``inputA.dat``   -- the int8 A matrix (m x k), written transposed when
      ``param.transA`` is truthy.
    * ``inputB.dat``   -- the nibble-packed int4 B matrix from
      ``gen_data_int4`` (packed in transposed layout when ``param.transB``).
    * ``inputC.dat``   -- ``1.5 * (A @ B)`` converted to float16.
    * ``expected.dat`` -- ``1.5 * (A @ B)`` as float32.

    NOTE(review): the 1.5 factor is hard-coded and its meaning is not shown
    here; how the consumer uses ``inputC.dat`` should be confirmed.
    """
    # The batch size is read but not used by the generator at present.
    _batch, rows, inner, cols = param.b, param.m, param.k, param.n

    lhs = gen_data_int8(rows, inner)
    rhs, rhs_packed = gen_data_int4(inner, cols, param.transB)
    rhs_packed.tofile(os.path.join(path, "inputB.dat"))

    # The reference product is computed in float32 from the unpacked values.
    product = np.dot(lhs.astype(np.float32), rhs.astype(np.float32))
    scaled = np.float32(1.5) * product

    # Only the on-disk form of A is transposed; the product above always
    # uses the (m, k) orientation.
    lhs_on_disk = lhs.T if param.transA else lhs
    lhs_on_disk.tofile(os.path.join(path, "inputA.dat"))

    scaled.astype(np.float16).tofile(os.path.join(path, "inputC.dat"))
    scaled.astype(np.float32).tofile(os.path.join(path, "expected.dat"))
if __name__ == "__main__":
    # Command line: <m> <n> <k> -- note that n is given before k.
    # Fail with a usage message instead of a bare IndexError when
    # arguments are missing.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} <m> <n> <k>")

    current_dir = os.path.dirname(os.path.abspath(__file__))

    param = OpParam()
    param.b = 1
    param.m = int(sys.argv[1])
    param.n = int(sys.argv[2])
    param.k = int(sys.argv[3])
    # Booleans rather than 0 so the flags match OpParam's defaults.
    param.transA = False
    param.transB = False

    # Output goes to a "data" directory next to this script.
    data_dir = os.path.join(current_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    gen_testcase(data_dir, param)