import ast
import struct
import sys

import numpy as np
def float32_to_bfloat16_bytes(f, byteorder='big'):
    """Convert a value to its 2-byte bfloat16 encoding, rounding to nearest-even.

    *f* is first packed as float32 (``struct`` format ``'f'``). That 32-bit
    pattern is then reduced to its upper 16 bits with round-to-nearest, ties
    to even. Plain truncation, used previously, biases every value toward
    zero.

    NaN inputs map to a quiet NaN that keeps the sign. Finite values that
    round past the largest bfloat16 become +/-inf.

    Args:
        f: Value to convert (Python float or NumPy scalar).
        byteorder: ``'big'`` (default; the historical output of this helper)
            or ``'little'`` (the layout NumPy ``tofile`` produces on
            little-endian hosts).

    Returns:
        The two bfloat16 bytes in the requested byte order.

    Raises:
        OverflowError: if *f* is too large to pack as float32 (from ``struct``).
        ValueError: if *byteorder* is neither ``'big'`` nor ``'little'``.
    """
    bits, = struct.unpack('>I', struct.pack('>f', f))
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    if exponent == 0xFF and mantissa:
        # NaN: adding the rounding bias could carry into the sign bit, so keep
        # the top half and force the quiet bit so the mantissa stays non-zero.
        bf16 = (bits >> 16) | 0x0040
    else:
        # Bias of 0x7FFF plus the kept LSB implements ties-to-even.
        bf16 = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16
    return bf16.to_bytes(2, byteorder)
def bfloat16_bytes_to_float32(bf16_bytes, byteorder='big'):
    """Decode a 2-byte bfloat16 encoding into the float value it represents.

    bfloat16 is the upper half of a float32, so the conversion is exact: the
    missing low 16 mantissa bits are filled with zeros.

    Args:
        bf16_bytes: The two encoded bytes.
        byteorder: ``'big'`` (default; the historical layout) or ``'little'``.

    Returns:
        The decoded value as a Python float.

    Raises:
        struct.error: if *bf16_bytes* is not exactly two bytes long.
        ValueError: if *byteorder* is neither ``'big'`` nor ``'little'``.
    """
    if byteorder == 'big':
        return struct.unpack('>f', bf16_bytes + b'\x00\x00')[0]
    if byteorder == 'little':
        # In little-endian order the zero-filled low half comes first.
        return struct.unpack('<f', b'\x00\x00' + bf16_bytes)[0]
    raise ValueError(f"byteorder must be 'big' or 'little', not {byteorder!r}")
def _fp32_to_bf16_bits(values):
    """Round float32 values to bfloat16 and return the raw uint16 bit patterns.

    Uses round-to-nearest, ties to even, on the upper 16 bits of each float32
    pattern. NaNs become a quiet NaN that keeps its sign. Finite values that
    round past the largest bfloat16 become +/-inf.
    """
    fp32 = np.ascontiguousarray(values, dtype=np.float32)
    # Widen to uint64 so adding the rounding bias can never wrap around.
    bits = fp32.view(np.uint32).astype(np.uint64)
    lsb = (bits >> 16) & 1  # ties go to the candidate whose low bit is 0
    rounded = ((bits + 0x7FFF + lsb) >> 16).astype(np.uint16)
    quiet_nan = ((bits >> 16) | 0x0040).astype(np.uint16)
    return np.where(np.isnan(fp32), quiet_nan, rounded)


def _bf16_bits_to_fp32(bits):
    """Expand uint16 bfloat16 bit patterns to the float32 values they encode exactly."""
    return (np.asarray(bits, dtype=np.uint32) << 16).view(np.float32)


def gen_data(shape_str, dtype_str, scalar_value):
    """Generate input and golden data for the Adds operator.

    Writes ``./input.bin`` (random input) and ``./golden.bin``
    (input + scalar). Both are raw, headerless arrays in native byte order.
    bfloat16 data is stored as uint16 bit patterns, also in native order,
    matching what ``tofile`` does for the other dtypes. The NumPy RNG is
    seeded with 42, so repeated runs produce identical files.

    Args:
        shape_str: Python literal for the tensor shape, e.g. ``"(8, 16)"``,
            ``"[128]"`` or ``"64"``.
        dtype_str: One of ``float32``, ``float16``, ``bfloat16``, ``int16``,
            ``int32``, ``int64``.
        scalar_value: Scalar added to every element. For the integer dtypes
            it is truncated toward zero with ``int()``.

    Raises:
        ValueError: if ``shape_str`` is not a plain literal or ``dtype_str``
            is unsupported.
        SyntaxError: if ``shape_str`` is not valid Python syntax.
    """
    # literal_eval instead of eval: the shape comes from the command line and
    # must never be able to execute code.
    shape = ast.literal_eval(shape_str)
    np.random.seed(42)

    if dtype_str in ('float32', 'float16'):
        dtype = np.float32 if dtype_str == 'float32' else np.float16
        input_data = np.random.uniform(-2, 2, size=shape).astype(dtype)
        output_data = (input_data + scalar_value).astype(dtype)
    elif dtype_str == 'bfloat16':
        sampled = np.random.uniform(-2, 2, size=shape).astype(np.float32)
        input_data = _fp32_to_bf16_bits(sampled)
        # Build the golden from the bfloat16 values actually written to
        # input.bin, not from the unrounded float32 samples.
        input_fp32 = _bf16_bits_to_fp32(input_data)
        # NOTE(review): the scalar is applied as float32 here; if the kernel
        # quantizes the scalar to bfloat16 first, do the same - confirm.
        output_data = _fp32_to_bf16_bits(input_fp32 + np.float32(scalar_value))
    elif dtype_str == 'int16':
        input_data = np.random.randint(-100, 100, size=shape, dtype=np.int16)
        # Add in int64 so the sum cannot wrap before clip() saturates it.
        wide_sum = input_data.astype(np.int64) + int(scalar_value)
        output_data = np.clip(wide_sum, -32768, 32767).astype(np.int16)
    elif dtype_str == 'int32':
        input_data = np.random.randint(-1000, 1000, size=shape, dtype=np.int32)
        output_data = (input_data + int(scalar_value)).astype(np.int32)
    elif dtype_str == 'int64':
        input_data = np.random.randint(-10000, 10000, size=shape, dtype=np.int64)
        output_data = (input_data + int(scalar_value)).astype(np.int64)
    else:
        raise ValueError(
            f"Unsupported dtype {dtype_str!r}; expected one of: "
            "float32, float16, bfloat16, int16, int32, int64"
        )

    input_data.tofile('./input.bin')
    output_data.tofile('./golden.bin')
    print(f"Generated data: shape={shape}, dtype={dtype_str}, scalar={scalar_value}")
def main(argv=None):
    """Command-line entry point: ``python3 gen_data.py <shape> <dtype> <scalar>``.

    Args:
        argv: Argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        Process exit status: 0 on success, 1 on bad or missing arguments.
    """
    args = sys.argv if argv is None else argv
    if len(args) < 4:
        print("Usage: python3 gen_data.py <shape> <dtype> <scalar>", file=sys.stderr)
        return 1
    shape_str, dtype_str, scalar_text = args[1:4]
    try:
        scalar_value = float(scalar_text)
    except ValueError:
        print(f"Invalid scalar {scalar_text!r}: expected a number", file=sys.stderr)
        return 1
    try:
        gen_data(shape_str, dtype_str, scalar_value)
    except (ValueError, SyntaxError) as err:
        # Bad shape literal or unsupported dtype: report it instead of a traceback.
        print(f"Error: {err}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())