import os
import numpy as np
import torch
import torch_npu
from utils import CommType, DataType, tensor_to_file
def gen_random_data(size, dtype):
    """Return a randomly filled tensor of the given shape and dtype.

    Args:
        size: Shape of the tensor to create, forwarded as ``size=`` to the
            torch random generators.
        dtype: A torch dtype. ``torch.float16``, ``torch.bfloat16`` and
            ``torch.float32`` are sampled with ``torch.randn``;
            ``torch.int8`` is sampled with ``torch.randint`` from the
            half-open interval [-16, 16).

    Returns:
        A new tensor with the requested ``size`` and ``dtype``.

    Raises:
        ValueError: If ``dtype`` is none of the supported dtypes. The same
            message is also printed before raising.
    """
    normal_sampled = (torch.float16, torch.bfloat16, torch.float32)
    if dtype in normal_sampled:
        return torch.randn(size=size, dtype=dtype)

    if dtype == torch.int8:
        # Small magnitude range for the integer case.
        return torch.randint(-16, 16, size=size, dtype=dtype)

    print(f"Invalid dtype: {dtype}.")
    raise ValueError(f"Invalid dtype: {dtype}")
def _parse_golden_args():
    """Parse and validate the command-line arguments for gen_golden_data.

    Positional arguments, in order: comm_type, out_dtype, pe_size, m, n, k,
    transA, transB and an optional data_dir (defaults to "./out").

    Returns:
        The populated ``argparse.Namespace``.

    Exits through ``parser.error`` (argparse usage error) when ``pe_size``
    is smaller than 1.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('comm_type', type=CommType.from_str,
                        choices=[CommType.MATMUL_ALLREDUCE,
                                 CommType.ALLGATHER_MATMUL,
                                 CommType.MATMUL_REDUCE_SCATTER,
                                 CommType.ALLGATHER_MATMUL_PADDING,
                                 CommType.MATMUL_REDUCE_SCATTER_PADDING,
                                 CommType.ALLGATHER_MATMUL_WITH_GATHER_RESULT])
    parser.add_argument('out_dtype', type=DataType.from_str,
                        choices=[DataType.FLOAT16, DataType.BF16])
    parser.add_argument('pe_size', type=int)
    parser.add_argument('m', type=int)
    parser.add_argument('n', type=int)
    parser.add_argument('k', type=int)
    parser.add_argument('transA', type=int)
    parser.add_argument('transB', type=int)
    # nargs='?' is what makes this positional optional; without it argparse
    # still requires the argument and the default is never used.
    parser.add_argument('data_dir', type=str, nargs='?',
                        help='Directory to save the data files',
                        default="./out")
    args = parser.parse_args()

    # At least one PE is needed: both reduction paths below index or
    # concatenate the per-PE result lists.
    if args.pe_size < 1:
        parser.error(f"pe_size must be a positive integer, got {args.pe_size}")
    return args


def gen_golden_data():
    """Generate per-PE matmul inputs and the reference outputs on disk.

    Reads its configuration from the command line (see _parse_golden_args).
    One B matrix of shape [k, n] is shared by every PE; each PE gets its own
    A matrix of shape [m, k]. For every PE ``i`` the function writes
    ``pe_{i}_a.bin`` and ``pe_{i}_b.bin`` (transposed first when transA /
    transB are non-zero) through ``tensor_to_file``.

    Two references are produced from the per-PE products A_i @ B:
      * ``golden.bin``: computed with NumPy in float32 and written with
        ``ndarray.tofile``.
      * ``torch_output.bin``: computed with ``torch.matmul`` on the NPU in
        the requested output dtype.
    For the all-gather communication types the per-PE products are
    concatenated along axis 0; for every other type they are summed.

    When comm_type is ALLGATHER_MATMUL_WITH_GATHER_RESULT the untransposed
    A matrices, concatenated along dim 0 and cast to float32, are also
    written to ``gather_a.bin``.
    """
    args = _parse_golden_args()
    m, n, k = args.m, args.n, args.k
    torch_dtype = args.out_dtype.torch_type

    data_dir = os.path.abspath(args.data_dir)
    os.makedirs(data_dir, exist_ok=True)

    # B is generated before any A so the random stream is consumed in the
    # same order as before (B, then A_0, A_1, ...).
    b_all_pe = gen_random_data([k, n], dtype=torch_dtype)

    # B is identical for every PE, so derive its views once, not per PE.
    b_np = b_all_pe.to(torch.float32).numpy()
    b_npu = b_all_pe.npu()
    b_to_save = b_all_pe.transpose(0, 1).contiguous() if args.transB else b_all_pe

    matrix_a_list = []
    matrix_c_list_fp32 = []
    matrix_c_list_npu = []
    for i in range(args.pe_size):
        a_gm = gen_random_data([m, k], dtype=torch_dtype)
        matrix_a_list.append(a_gm)

        # float32 reference product on the host.
        matrix_c_list_fp32.append(np.matmul(a_gm.to(torch.float32).numpy(), b_np))
        # Device product in the output dtype, moved back to the host once.
        matrix_c_list_npu.append(torch.matmul(a_gm.npu(), b_npu).cpu())

        # The products above always use the untransposed operands; only the
        # saved layout is transposed.
        a_to_save = a_gm.transpose(0, 1).contiguous() if args.transA else a_gm
        tensor_to_file(a_to_save, os.path.join(data_dir, f"pe_{i}_a.bin"))
        tensor_to_file(b_to_save, os.path.join(data_dir, f"pe_{i}_b.bin"))

    allgather_types = (CommType.ALLGATHER_MATMUL,
                       CommType.ALLGATHER_MATMUL_PADDING,
                       CommType.ALLGATHER_MATMUL_WITH_GATHER_RESULT)
    if args.comm_type in allgather_types:
        golden = np.concatenate(matrix_c_list_fp32, axis=0)
        torch_output = torch.cat(matrix_c_list_npu, dim=0)
    else:
        # Accumulate sequentially from zero, in PE order, so the rounding
        # of the sums is unchanged.
        golden = np.zeros_like(matrix_c_list_fp32[0])
        torch_output = torch.zeros_like(matrix_c_list_npu[0])
        for c_fp32, c_npu in zip(matrix_c_list_fp32, matrix_c_list_npu):
            golden += c_fp32
            torch_output += c_npu

    tensor_to_file(torch_output, os.path.join(data_dir, "torch_output.bin"))
    golden.tofile(os.path.join(data_dir, "golden.bin"))

    if args.comm_type == CommType.ALLGATHER_MATMUL_WITH_GATHER_RESULT:
        gathered_a = torch.cat(matrix_a_list, dim=0).to(torch.float32)
        tensor_to_file(gathered_a, os.path.join(data_dir, "gather_a.bin"))
# Script entry point: generate the data only when run directly, not on import.
if __name__ == "__main__":
    gen_golden_data()