import os
import sys
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
import numpy as np
import torch
from ml_dtypes import bfloat16
DATA_TYPE = torch.bfloat16
def write_artifacts(base_dir, a_data, b_data, out):
input_dir = os.path.join(base_dir, "input")
output_dir = os.path.join(base_dir, "output")
os.makedirs(input_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
a_data.view(torch.uint16).numpy().tofile(os.path.join(input_dir, "input_a.bin"))
b_data.view(torch.uint16).numpy().tofile(os.path.join(input_dir, "input_b.bin"))
out.view(torch.uint16).numpy().tofile(os.path.join(output_dir, "cpu_output.bin"))
def gen_golden_data_simple(m, k, n):
    """Generate random matmul operands plus the CPU reference result.

    Builds an ``(m, k)`` matrix A and a ``(k, n)`` matrix B, each sampled with
    ``np.random.uniform(1, 8, ...)`` as float32 and then cast to ``DATA_TYPE``.
    The product ``A @ B`` is computed with ``torch.matmul`` on the cast
    tensors. All three tensors are handed to ``write_artifacts`` for the
    current working directory and, when it is a different directory, for the
    directory containing this script as well.

    Args:
        m: Row count of A (and of the result).
        k: Column count of A / row count of B.
        n: Column count of B (and of the result).
    """

    def random_operand(rows, cols):
        # Sample in float32 first, then narrow to the target dtype.
        values = np.random.uniform(1, 8, (rows, cols)).astype(np.float32)
        return torch.from_numpy(values).to(DATA_TYPE)

    # A is drawn before B so the random stream is consumed in a fixed order.
    lhs = random_operand(m, k)
    rhs = random_operand(k, n)
    product = torch.matmul(lhs, rhs).to(DATA_TYPE)

    working_dir = os.getcwd()
    write_artifacts(working_dir, lhs, rhs, product)

    def canonical(path):
        return os.path.normcase(os.path.abspath(path))

    # Mirror the files next to the script unless that is the same directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if canonical(script_dir) != canonical(working_dir):
        write_artifacts(script_dir, lhs, rhs, product)

    print("Data generated successfully!")
if __name__ == "__main__":
    # Command-line entry point: ``gen_data.py m k n [transA transB]``.
    usage_text = "\n".join(
        (
            "Usage: python3 gen_data.py m k n",
            "Or",
            "Usage: python3 gen_data.py m k n transA transB",
            "Example1: python3 gen_data.py 100 50 200",
            "Example2: python3 gen_data.py 100 50 200 false true",
        )
    )

    if len(sys.argv) not in (4, 6):
        print(usage_text)
        sys.exit(1)

    # Report malformed dimensions as a usage error rather than a traceback.
    try:
        m, k, n = (int(arg) for arg in sys.argv[1:4])
    except ValueError:
        print(
            f"Error: m, k and n must be integers, got {sys.argv[1:4]}",
            file=sys.stderr,
        )
        print(usage_text)
        sys.exit(1)

    # The six-argument form is accepted, but nothing in this script reads the
    # two transpose flags; say so instead of silently dropping them.
    if len(sys.argv) == 6:
        print(
            "Warning: transA/transB are accepted but ignored by this script; "
            "A is written as (m, k) and B as (k, n).",
            file=sys.stderr,
        )

    gen_golden_data_simple(m, k, n)