import os
import struct
import math
from dataclasses import dataclass
from typing import Optional
import numpy as np
from ml_dtypes import float8_e4m3fn, bfloat16
# Seed NumPy's global RNG so that the np.random.* calls made by the generators
# in this file produce the same data on every run.
np.random.seed(19)
def float_to_hex(f):
    """Return the little-endian binary32 bit pattern of ``f`` as a ``hex()`` string."""
    raw_bytes = struct.pack("<f", f)
    (bit_pattern,) = struct.unpack("<I", raw_bytes)
    return hex(bit_pattern)
def scale_data(data_fp32, data_scaling, group_size=32):
    """Scale grouped data, saturate to +/-448 and cast to float8_e4m3fn.

    The input is reshaped to ``(-1, group_size)`` and multiplied by
    ``data_scaling`` (broadcast by NumPy) before clipping and casting.
    """
    e4m3_limit = 448
    grouped = data_fp32.reshape(-1, group_size)
    saturated = np.clip(grouped * data_scaling, -e4m3_limit, e4m3_limit)
    return saturated.astype(float8_e4m3fn)
def scale_data_bf16(data_bf16, data_scaling, group_size=32):
    """Scale grouped data with a bfloat16 intermediate, then cast to float8_e4m3fn.

    The product is rounded to bfloat16 before it is clipped to the bfloat16
    bounds -448 / 448 and converted to float8_e4m3fn.
    """
    lower_bound = bfloat16(-448)
    upper_bound = bfloat16(448)
    grouped = data_bf16.reshape(-1, group_size)
    product_bf16 = (grouped * data_scaling).astype(bfloat16)
    return np.clip(product_bf16, lower_bound, upper_bound).astype(float8_e4m3fn)
def scale_data_fp16(data_fp16, data_scaling, group_size=32):
    """Scale grouped data with a float16 intermediate, then cast to float8_e4m3fn.

    The product is rounded to float16 before it is clipped to the float16
    bounds -448 / 448 and converted to float8_e4m3fn.
    """
    lower_bound = np.float16(-448)
    upper_bound = np.float16(448)
    grouped = data_fp16.reshape(-1, group_size)
    product_fp16 = (grouped * data_scaling).astype(np.float16)
    return np.clip(product_fp16, lower_bound, upper_bound).astype(float8_e4m3fn)
def scale_data_fp16_nv(data_fp16, data_scaling, group_size=32):
    """Scale grouped data in float32, saturate to +/-448 and cast to float8_e4m3fn.

    Both the reshaped data and ``data_scaling`` are widened to float32 before
    the multiplication, so no half-precision rounding happens in between.
    """
    grouped_fp32 = data_fp16.reshape(-1, group_size).astype(np.float32)
    product = grouped_fp32 * data_scaling.astype(np.float32)
    saturated = np.clip(product, -448.0, 448.0).astype(np.float32)
    return saturated.astype(float8_e4m3fn)
def get_group_max_last_dim(data: np.ndarray, group_size: int = 32):
    """Return the maximum absolute value of each consecutive ``group_size`` run.

    ``data`` is flattened into rows of ``group_size`` elements; the result is
    a 1-D array with one maximum per row.
    """
    magnitudes = np.abs(data).reshape(-1, group_size)
    return magnitudes.max(axis=1)
def fp32_to_fp8_element(data_abs_max, emax):
    """Derive the shared E8M0 exponent and float32 scaling for one group maximum.

    Args:
        data_abs_max: Group maximum; it is converted to float32 and its raw
            bit pattern is inspected.
        emax: Offset subtracted from the biased float32 exponent (the callers
            in this file pass 8).

    Returns:
        ``(e8m0, scaling)`` where
        * NaN input gives ``(0xFF, NaN)``;
        * a biased exponent ``<= emax`` gives ``(0x00, 2**127)``;
        * otherwise ``e8m0 = exponent - emax`` and ``scaling`` is the float32
          whose exponent field is ``254 - e8m0`` (i.e. ``2**(127 - e8m0)``).
          When that field would be zero the scaling is ``2**-127`` instead of 0.
    """
    bits = np.uint32(np.frombuffer(np.float32(data_abs_max).tobytes(), dtype=np.uint32)[0])
    exponent_b32 = int((bits & np.uint32(0x7F800000)) >> np.uint32(23))
    mantissa_b32 = int(bits & np.uint32(0x007FFFFF))

    if exponent_b32 == 0xFF and mantissa_b32 != 0:
        # NaN propagates as the E8M0 NaN code with a quiet-NaN scaling.
        return 0xFF, np.uint32(0x7FC00000).view(np.float32)
    if exponent_b32 <= emax:
        return 0x00, np.uint32(0x7F000000).view(np.float32)

    e8m0 = exponent_b32 - emax
    scale_exp = 254 - e8m0
    scaling = np.uint32(scale_exp << 23).view(np.float32)
    if scaling == 0.0:
        # An all-zero exponent field would encode 0.0; use the subnormal
        # 2**-127 instead.  np.float32(...) replaces np.pow, which is only
        # available in NumPy >= 2.0 and returned a float64 here.
        scaling = np.float32(2.0**-127)
    return e8m0, scaling
def nd2nz_mxfp8(data_fp8, tile_m, tile_n):
    """Regroup a tile into 32-wide column groups and move the group axis first.

    The data is viewed as ``(tile_m, ceil(tile_n / 32), 32)`` and returned
    with axes ordered ``(group, row, element)``.
    """
    row_count = int(tile_m)
    groups_per_row = int(math.ceil(tile_n / 32))
    grouped = data_fp8.reshape(row_count, groups_per_row, 32)
    return grouped.transpose(1, 0, 2)
def nd2zz_e8m0(e8m0, tile_m, tile_n_div_32):
    """Reorder E8M0 scales into 16x2 blocks and dump the matching index table.

    The scales are viewed as ``(ceil(tile_m / 16), 16, ceil(tile_n_div_32 / 2), 2)``
    and the two middle axes are swapped.  The same permutation is applied to
    the flat element indices; every second entry of ``index // 2`` is written
    as uint16 to ``index_vselr_b16.bin`` in the current directory.

    Returns the permuted scales as a uint8 array.
    """
    block_shape = (
        int(math.ceil(tile_m / 16)),
        16,
        int(math.ceil(tile_n_div_32 / 2)),
        2,
    )
    block_axes = [0, 2, 1, 3]

    flat_positions = np.arange(e8m0.size).reshape(e8m0.shape)
    permuted_positions = np.transpose(flat_positions.reshape(block_shape), block_axes).flatten()
    b16_positions = (permuted_positions // 2)[::2].astype(np.uint16)
    b16_positions.tofile("index_vselr_b16.bin")

    return np.transpose(e8m0.reshape(block_shape), block_axes).astype(np.uint8)
def fp32_maxes_to_fp8(data_abs_max, emax=8):
    """Apply ``fp32_to_fp8_element`` to every group maximum.

    Returns ``(e8m0s, scalings)``: a flat uint8 array of E8M0 codes and a
    float32 column vector (shape ``(-1, 1)``) of scalings.
    """
    per_group = [fp32_to_fp8_element(group_max, emax=emax) for group_max in data_abs_max.reshape(-1).tolist()]
    e8m0s = np.array([code for code, _ in per_group]).astype(np.uint8)
    scalings = np.array([scale for _, scale in per_group]).reshape(-1, 1).astype(np.float32)
    return e8m0s, scalings
def nv_fp32_to_mx_element(data_abs_max, max_value):
    """Compute the "nv" E8M0 code and float32 scaling for one group maximum.

    The maximum is multiplied (in float32) by ``1 / max_value`` and the
    exponent of that quotient is rounded up whenever its mantissa is non-zero,
    except when the exponent is already 0xFE or the value is a subnormal with
    mantissa ``<= 0x00400000``.

    Returns:
        ``(e8m0, scaling)``.  Zero input and ``e8m0 == 0`` give scaling
        ``2**127``; infinity and ``e8m0 == 0xFE`` give ``2**-127``; NaN gives
        ``(0xFF, NaN)``; otherwise the scaling has exponent field ``254 - e8m0``.
    """
    abs_max = np.float32(data_abs_max)
    if abs_max == np.float32(0.0):
        return 0x00, np.uint32(0x7F000000).view(np.float32)

    descale = abs_max * np.float32(1.0 / max_value)
    raw = np.uint32(np.frombuffer(descale.tobytes(), dtype=np.uint32)[0])
    exponent = int((raw & np.uint32(0x7F800000)) >> np.uint32(23))
    mantissa = int(raw & np.uint32(0x007FFFFF))

    if exponent == 0xFF:
        if mantissa != 0:
            return 0xFF, np.float32(np.nan)
        return 0xFE, np.float32(2.0**-127)

    # De Morgan form of: not (exponent == 0 and mantissa <= 0x00400000)
    above_small_subnormal = exponent != 0 or mantissa > 0x00400000
    round_up = mantissa > 0 and exponent != 0xFE and above_small_subnormal
    e8m0 = exponent + 1 if round_up else exponent

    if e8m0 == 0:
        return e8m0, np.uint32(0x7F000000).view(np.float32)
    if e8m0 == 0xFE:
        return e8m0, np.float32(2.0**-127)
    return e8m0, np.uint32((254 - e8m0) << 23).view(np.float32)
def nv_fp32_to_fp8_element(data_abs_max):
    """Run ``nv_fp32_to_mx_element`` with ``max_value`` fixed to 448.0."""
    e8m0, scaling = nv_fp32_to_mx_element(data_abs_max, max_value=448.0)
    return e8m0, scaling
def nv_maxes_to_fp8(data_abs_max):
    """Apply ``nv_fp32_to_fp8_element`` to every group maximum.

    Returns a flat uint8 array of E8M0 codes and a float32 column vector of
    scalings (shape ``(-1, 1)``).
    """
    per_group = [nv_fp32_to_fp8_element(group_max) for group_max in data_abs_max.reshape(-1).tolist()]
    codes = np.array([code for code, _ in per_group]).astype(np.uint8)
    scales = np.array([scale for _, scale in per_group]).reshape(-1, 1).astype(np.float32)
    return codes, scales
def fp16_to_fp8_element(data_abs_max_fp16, emax):
    """Delegate to ``bf16_to_fp8_element``; ``emax`` is accepted but ignored."""
    del emax  # kept in the signature only; the bf16 routine takes no exponent offset
    e8m0, scale_bits = bf16_to_fp8_element(data_abs_max_fp16)
    return e8m0, scale_bits
def bf16_to_fp8_element(data_abs_max_bf16):
    """Derive the E8M0 code and bf16 scaling bits from a group maximum.

    Only the upper 16 bits of the float32 representation of the input are
    used.  NaN returns ``(0xFF, uint16(0x7F81))``.  Otherwise the exponent
    field is clamped to at least 0x0400 and reduced by 0x0400; the E8M0 code
    is that difference shifted right by 7, and the scaling bit pattern is
    ``0x7F00 - difference``.
    """
    exponent_mask = 0x7F80
    mantissa_mask = 0x007F
    exponent_floor = 0x0400

    fp32_bits = np.frombuffer(np.float32(data_abs_max_bf16).tobytes(), dtype=np.uint32)[0]
    upper_half = np.uint16(fp32_bits >> 16)
    exponent_field = int(upper_half & exponent_mask)
    mantissa_field = int(upper_half & mantissa_mask)

    if exponent_field == exponent_mask and mantissa_field != 0:
        return 0xFF, np.uint16(0x7F81)

    if exponent_field < exponent_floor:
        exponent_field = exponent_floor
    shared_exp_bits = exponent_field - exponent_floor
    return (shared_exp_bits >> 7) & 0xFF, np.uint16(0x7F00 - shared_exp_bits)
def fp16_maxes_to_fp8(data_abs_max, emax=8):
    """Apply ``fp16_to_fp8_element`` to every group maximum.

    The per-group uint16 scaling bit patterns are expanded to float32 through
    ``bf16_bits_to_float32``.  Returns a flat uint8 array of E8M0 codes and a
    float32 column vector of scalings.
    """
    per_group = [fp16_to_fp8_element(group_max, emax=emax) for group_max in data_abs_max.reshape(-1).tolist()]
    e8m0s = np.array([code for code, _ in per_group]).astype(np.uint8)
    bit_patterns = np.array([bits for _, bits in per_group], dtype=np.uint16)
    scalings = bf16_bits_to_float32(bit_patterns).reshape(-1, 1)
    return e8m0s, scalings.astype(np.float32)
def bf16_maxes_to_fp8(data_abs_max):
    """Forward to ``fp16_maxes_to_fp8`` with ``emax=8``."""
    e8m0s, scalings = fp16_maxes_to_fp8(data_abs_max, emax=8)
    return e8m0s, scalings
def float32_to_bf16_trunc(data):
    """Zero the low 16 bits of each float32 value (truncation, no rounding).

    The result is still a float32 array; only the bits that survive a
    truncating conversion to bfloat16 are kept.
    """
    as_fp32 = np.asarray(data, dtype=np.float32)
    keep_upper_half = np.uint32(0xFFFF0000)
    truncated_bits = np.bitwise_and(as_fp32.view(np.uint32), keep_upper_half)
    return truncated_bits.view(np.float32)
def bf16_bits_to_float32(bits):
    """Reinterpret 16-bit bfloat16 patterns as float32 values.

    Each pattern is widened to uint32, shifted into the upper half-word and
    viewed as float32.
    """
    widened = np.uint16(bits).astype(np.uint32)
    shifted = np.uint32(widened << np.uint32(16))
    return shifted.view(np.float32)
def fp16_to_e2m1_element(data_abs_max_bf16):
    """Derive the E8M0 code and bf16 scaling bits for an E2M1 group maximum.

    Only the upper 16 bits of the float32 representation of the input are
    used.  NaN returns ``(0xFF, uint16(0x7FC0))``.  Otherwise the exponent
    field is clamped to at least 0x0100 and reduced by 0x0100; the E8M0 code
    is that difference shifted right by 7, and the scaling bit pattern is
    ``0x7F00 - difference``.
    """
    exponent_mask = 0x7F80
    mantissa_mask = 0x007F
    exponent_floor = 0x0100

    fp32_bits = np.frombuffer(np.float32(data_abs_max_bf16).tobytes(), dtype=np.uint32)[0]
    upper_half = np.uint16(fp32_bits >> 16)
    exponent_field = int(upper_half & exponent_mask)
    mantissa_field = int(upper_half & mantissa_mask)

    if exponent_field == exponent_mask and mantissa_field != 0:
        return 0xFF, np.uint16(0x7FC0)

    if exponent_field < exponent_floor:
        exponent_field = exponent_floor
    shared_exp_bits = exponent_field - exponent_floor
    return (shared_exp_bits >> 7) & 0xFF, np.uint16(0x7F00 - shared_exp_bits)
def fp16_maxes_to_e2m1(data_abs_max):
    """Apply ``fp16_to_e2m1_element`` to every group maximum.

    Returns a flat uint8 array of E8M0 codes and a float32 column vector of
    scalings obtained from the bf16 bit patterns.
    """
    per_group = [fp16_to_e2m1_element(group_max) for group_max in data_abs_max.reshape(-1).tolist()]
    codes = np.array([code for code, _ in per_group]).astype(np.uint8)
    bit_patterns = np.array([bits for _, bits in per_group], dtype=np.uint16)
    scales = bf16_bits_to_float32(bit_patterns).reshape(-1, 1)
    return codes, scales.astype(np.float32)
def nv_fp32_to_e2m1_element(data_abs_max):
    """Run ``nv_fp32_to_mx_element`` with ``max_value`` fixed to 6.0."""
    e8m0, scaling = nv_fp32_to_mx_element(data_abs_max, max_value=6.0)
    return e8m0, scaling
def nv_maxes_to_e2m1(data_abs_max):
    """Apply ``nv_fp32_to_e2m1_element`` to every group maximum.

    Returns a flat uint8 array of E8M0 codes and a float32 column vector of
    scalings (shape ``(-1, 1)``).
    """
    per_group = [nv_fp32_to_e2m1_element(group_max) for group_max in data_abs_max.reshape(-1).tolist()]
    codes = np.array([code for code, _ in per_group]).astype(np.uint8)
    scales = np.array([scale for _, scale in per_group]).reshape(-1, 1).astype(np.float32)
    return codes, scales
def encode_e2m1_magic_scalar(value):
    """Encode one value as a 4-bit code using a magic-number float addition.

    Bit 3 of the result carries the sign of ``value``; bits 0-2 carry the
    magnitude code, saturated at 7.  NaN returns 0x7 (sign dropped) and
    +/-infinity returns ``sign | 0x7``.

    The biased exponent of ``|value|`` is clamped to [127, 129]; adding the
    float32 whose exponent field is ``clamped + 22`` lets the float32 addition
    perform the rounding, after which the integer difference of the bit
    patterns is the rounded step count.
    """
    value = np.float32(value)
    raw_bits = np.frombuffer(value.tobytes(), dtype=np.uint32)[0]
    sign_nibble = np.uint8((raw_bits >> 28) & 0x8)
    magnitude = np.float32(abs(value))

    if np.isnan(magnitude):
        return np.uint8(0x7)
    if np.isinf(magnitude):
        return np.uint8(sign_nibble | 0x7)

    magnitude_bits = np.frombuffer(magnitude.tobytes(), dtype=np.uint32)[0]
    exponent_field = int((magnitude_bits & np.uint32(0x7F800000)) >> 23)
    exponent_field = max(127, min(exponent_field, 129))

    magic_bits = np.uint32((exponent_field + 22) << 23)
    magic_value = magic_bits.view(np.float32)
    summed = np.float32(magnitude + magic_value)
    step_count = np.frombuffer(summed.tobytes(), dtype=np.uint32)[0] - magic_bits

    exponent_offset = (exponent_field - 127) << 1
    magnitude_code = min(int(step_count) + exponent_offset, 7)
    return np.uint8(sign_nibble | magnitude_code)
def pack_fp4_e2m1(codes, rows, cols):
    """Pack 4-bit codes two per byte, even columns in the low nibble.

    Args:
        codes: Array with ``rows * cols`` elements (uint8 at the call sites in
            this file); it is reshaped to ``(rows, cols)``.
        rows: Number of rows.
        cols: Number of codes per row; an odd count leaves the high nibble of
            the last byte of each row at zero.

    Returns:
        uint8 array of shape ``(rows, (cols + 1) // 2)``.
    """
    packed_cols = (cols + 1) // 2
    packed = np.zeros((rows, packed_cols), dtype=np.uint8)
    codes = codes.reshape(rows, cols)

    # Vectorised replacement for the former per-element Python double loop:
    # even columns fill whole bytes first, then odd columns overwrite the
    # high nibble of the byte they share with their left neighbour.
    low_nibbles = codes[:, 0::2]
    high_nibbles = codes[:, 1::2]
    packed[:, : low_nibbles.shape[1]] = low_nibbles
    paired = high_nibbles.shape[1]
    packed[:, :paired] = (packed[:, :paired] & 0x0F) | (high_nibbles << 4)
    return packed
def quant_fp16_to_e2m1(src, scale_alg="ocp"):
    """Quantize ``src`` to packed 4-bit codes with per-32-element scales.

    With ``scale_alg == "nv"`` the group maxima are taken from the float16
    absolute values and passed to ``nv_maxes_to_e2m1``; otherwise they are
    taken from the bf16-truncated float32 data and passed to
    ``fp16_maxes_to_e2m1``.

    Writes ``golden_e8m0.bin``, ``scaling_e2m1.bin`` (float32) and
    ``golden_fp4.bin`` to the current directory and returns
    ``(e8m0, scaling, packed)``.
    """
    src_fp32 = src.astype(np.float32)

    if scale_alg != "nv":
        truncated = float32_to_bf16_trunc(src_fp32)
        magnitudes = np.abs(truncated).astype(np.float32)
        per_group_max = np.max(magnitudes.reshape(-1, 32), axis=1)
        e8m0, scaling_bf16 = fp16_maxes_to_e2m1(per_group_max)
    else:
        magnitudes = np.abs(src).astype(np.float16)
        with np.errstate(invalid="ignore"):
            per_group_max = np.max(magnitudes.reshape(-1, 32), axis=1).astype(np.float32)
        e8m0, scaling_bf16 = nv_maxes_to_e2m1(per_group_max)

    scaled_groups = src_fp32.reshape(-1, 32) * scaling_bf16.astype(np.float32)
    scaled = scaled_groups.reshape(src.shape).astype(np.float32)
    encode = np.vectorize(encode_e2m1_magic_scalar, otypes=[np.uint8])
    packed = pack_fp4_e2m1(encode(scaled), src.shape[0], src.shape[1])

    e8m0.tofile("golden_e8m0.bin")
    scaling_bf16.astype(np.float32).tofile("scaling_e2m1.bin")
    packed.tofile("golden_fp4.bin")
    return e8m0, scaling_bf16, packed
def quant_fp32_to_e4m3(src, mode="nd", scale_alg="ocp"):
    """Quantize float32 ``src`` to float8_e4m3fn with per-32-element scales.

    Scales come from ``nv_maxes_to_fp8`` when ``scale_alg == "nv"`` and from
    ``fp32_maxes_to_fp8`` (``emax=8``) otherwise.  In ``"nz"`` mode the data
    and the E8M0 codes are additionally reordered by ``nd2nz_mxfp8`` and
    ``nd2zz_e8m0``.

    Writes ``golden_e8m0.bin``, ``scaling_e4m3.bin`` and ``golden_fp8.bin``
    and returns ``(e8m0, scaling, data_fp8, group_max)``.
    """
    group_max = get_group_max_last_dim(src, group_size=32)
    e8m0, scaling = (
        nv_maxes_to_fp8(group_max) if scale_alg == "nv" else fp32_maxes_to_fp8(group_max, emax=8)
    )

    # Scaling is identical for both layouts; only "nz" reorders afterwards.
    data_fp8 = scale_data(src, scaling, group_size=32)
    if mode == "nz":
        tile_m, tile_n = src.shape[0], src.shape[1]
        data_fp8 = nd2nz_mxfp8(data_fp8, tile_m, tile_n)
        e8m0 = nd2zz_e8m0(e8m0, tile_m, int(tile_n / 32))

    e8m0.tofile("golden_e8m0.bin")
    scaling.tofile("scaling_e4m3.bin")
    data_fp8.tofile("golden_fp8.bin")
    return e8m0, scaling, data_fp8, group_max
def quant_bf16_to_e4m3(src, mode="nd", scale_alg="ocp"):
    """Quantize bfloat16 ``src`` to float8_e4m3fn with per-32-element scales.

    Group maxima are computed on the bfloat16 absolute values.  Scales come
    from ``nv_maxes_to_fp8`` when ``scale_alg == "nv"`` and from
    ``bf16_maxes_to_fp8`` otherwise, and are applied in bfloat16.  In
    ``"nz"`` mode the outputs are reordered by ``nd2nz_mxfp8`` / ``nd2zz_e8m0``.

    Writes ``golden_e8m0.bin``, ``scaling_e4m3.bin`` (float32) and
    ``golden_fp8.bin`` and returns
    ``(e8m0, scaling_bf16, data_fp8, group_max_bf16)``.
    """
    magnitudes = np.abs(src).astype(bfloat16).reshape(-1, 32)
    with np.errstate(invalid="ignore"):
        group_max_bf16 = np.max(magnitudes, axis=1)
    group_max_fp32 = group_max_bf16.astype(np.float32)

    derive_scales = nv_maxes_to_fp8 if scale_alg == "nv" else bf16_maxes_to_fp8
    e8m0, scaling_fp32 = derive_scales(group_max_fp32)
    scaling_bf16 = scaling_fp32.astype(bfloat16)

    data_fp8 = scale_data_bf16(src, scaling_bf16, group_size=32)
    if mode == "nz":
        tile_m, tile_n = src.shape[0], src.shape[1]
        data_fp8 = nd2nz_mxfp8(data_fp8, tile_m, tile_n)
        e8m0 = nd2zz_e8m0(e8m0, tile_m, int(tile_n / 32))

    e8m0.tofile("golden_e8m0.bin")
    scaling_fp32.tofile("scaling_e4m3.bin")
    data_fp8.tofile("golden_fp8.bin")
    return e8m0, scaling_bf16, data_fp8, group_max_bf16
def quant_fp16_to_e4m3(src, mode="nd", scale_alg="ocp"):
    """Quantize float16 ``src`` to float8_e4m3fn with per-32-element scales.

    With ``scale_alg == "nv"`` the group maxima come from the float16
    absolute values and the scaling from ``nv_maxes_to_fp8`` is cast to
    bfloat16.  Otherwise the maxima come from the bf16-truncated float32 data
    and are passed to ``fp16_maxes_to_fp8``.  The scaling is applied with
    ``scale_data_fp16_nv`` in both cases; ``"nz"`` mode reorders the outputs.

    Writes ``golden_e8m0.bin``, ``scaling_e4m3.bin`` (float32) and
    ``golden_fp8.bin`` and returns ``(e8m0, scaling, data_fp8)``.
    """
    if scale_alg != "nv":
        truncated = float32_to_bf16_trunc(src.astype(np.float32))
        magnitudes = np.abs(truncated).astype(np.float32)
        per_group_max = np.max(magnitudes.reshape(-1, 32), axis=1).astype(np.float32)
        e8m0, scaling = fp16_maxes_to_fp8(per_group_max, emax=-104)
    else:
        magnitudes = np.abs(src).astype(np.float16)
        per_group_max = np.max(magnitudes.reshape(-1, 32), axis=1).astype(np.float32)
        e8m0, scaling_fp32 = nv_maxes_to_fp8(per_group_max)
        scaling = scaling_fp32.astype(bfloat16)

    data_fp8 = scale_data_fp16_nv(src, scaling, group_size=32)
    if mode == "nz":
        tile_m, tile_n = src.shape[0], src.shape[1]
        data_fp8 = nd2nz_mxfp8(data_fp8, tile_m, tile_n)
        e8m0 = nd2zz_e8m0(e8m0, tile_m, int(tile_n / 32))

    e8m0.tofile("golden_e8m0.bin")
    scaling.astype(np.float32).tofile("scaling_e4m3.bin")
    data_fp8.tofile("golden_fp8.bin")
    return e8m0, scaling, data_fp8
# Number of consecutive elements per group in the boundary / fuzz data
# generators of this file.
MX_BOUNDARY_GROUP_SIZE = 32
# Scale-algorithm selectors.  "any" is the wildcard key that
# get_mx_boundary_pattern falls back to when no algorithm-specific pattern
# builder is registered.
MX_SCALE_ALG_ANY = "any"
MX_SCALE_ALG_OCP = "ocp"
MX_SCALE_ALG_NV = "nv"
# Destination-format key used in MX_BOUNDARY_PATTERN_BUILDERS.
MX_DST_MXFP8 = "mxfp8"
# Source-dtype keys produced by get_mx_src_dtype_key.
MX_SRC_FP16 = "fp16"
MX_SRC_BF16 = "bf16"
MX_SRC_FP32 = "fp32"
# Signed values with magnitude <= 1.0 that are spliced into the MXFP4 rounding
# pattern and the exp2d base pattern.
MXFP4_LOW_MAGNITUDE_VALUES = (1.0, -1.0, 0.75, -0.75, 0.5, -0.5, 0.375, -0.375, 0.25, -0.25, 0.125, -0.125)
@dataclass(frozen=True)
class MxFp4DataConfig:
    """Immutable parameter bundle read by ``make_mxfp4_e2m1_data``."""

    # Row count of the generated array.
    valid_rows: int
    # Column count of the generated array.
    valid_cols: int
    # Name of the data pattern to generate; None selects the default branch.
    case_suffix: Optional[str]
    # Element type of the generated array (np.float16 or bfloat16 at the
    # call sites in this file).
    dtype: object
    # Scale algorithm; only consulted for the "boundary" case.
    scale_alg: str = MX_SCALE_ALG_OCP
def get_mx_src_dtype_key(dtype):
    """Map a float16 / bfloat16 / float32 dtype to its ``MX_SRC_*`` key.

    Raises:
        ValueError: if ``dtype`` is none of the three supported types.
    """
    resolved = np.dtype(dtype)
    supported = (
        (np.float16, MX_SRC_FP16),
        (bfloat16, MX_SRC_BF16),
        (np.float32, MX_SRC_FP32),
    )
    for candidate, key in supported:
        if resolved == np.dtype(candidate):
            return key
    raise ValueError(f"Unsupported MX source dtype: {resolved}")
def get_mx_src_dtype(dtype_key):
    """Return the scalar type for an ``MX_SRC_*`` key (``KeyError`` if unknown)."""
    key_to_type = {
        MX_SRC_FP16: np.float16,
        MX_SRC_BF16: bfloat16,
        MX_SRC_FP32: np.float32,
    }
    return key_to_type[dtype_key]
def make_mxfp8_fp16_boundary_pattern():
    """Return 16 float16 boundary values specified by their raw bit patterns.

    The patterns cover signed zeros, the smallest and largest subnormals, the
    smallest normal, values around 0x5F00 (448.0) and 0x6300 (896.0), the
    largest finite value, infinity and a NaN.
    """
    raw_bits = (
        0x0000, 0x8000,  # +0, -0
        0x0001, 0x8001,  # smallest subnormal, both signs
        0x03FF, 0x83FF,  # largest subnormal, both signs
        0x0400, 0x8400,  # smallest normal, both signs
        0x5EFF, 0x5F00, 0x5F01,  # just below / at / just above 448.0
        0x6300, 0x6301,  # 896.0 and the next value up
        0x7BFF,  # largest finite float16
        0x7C00,  # +infinity
        0x7E00,  # NaN
    )
    return np.array(raw_bits, dtype=np.uint16).view(np.float16)
def make_mxfp8_bf16_boundary_pattern():
    """Return 16 bfloat16 boundary values for the MXFP8 conversion.

    The list holds +/- pairs of 0, 2**-133, 2**-126, 446, 448 and 450,
    followed by 896, 898, infinity and NaN; it is built in float32 and then
    cast to bfloat16.
    """
    mirrored_magnitudes = (
        0.0,
        np.float32(2.0**-133),
        np.float32(2.0**-126),
        446.0,
        448.0,
        450.0,
    )
    values = []
    for magnitude in mirrored_magnitudes:
        values.extend((magnitude, -magnitude))
    values.extend((896.0, 898.0, np.inf, np.nan))
    return np.array(values, dtype=np.float32).astype(bfloat16)
def make_mxfp8_fp32_boundary_pattern():
    """Return 16 float32 boundary values for the MXFP8 conversion.

    The list holds +/- pairs of 0, 2**-149, 2**-130, the float32 just below
    448, 448 itself and the float32 just above 448, followed by 896, the
    float32 just above 896, infinity and NaN.

    The negative "just above 448" entry is the exact negation of its positive
    partner.  It was previously built as ``-nextafter(448, -inf)``, which is
    the value just *below* 448 in magnitude and duplicated an earlier entry.
    """
    smallest_subnormal = np.float32(2.0**-149)
    small_subnormal = np.float32(2.0**-130)
    below_448 = np.nextafter(np.float32(448.0), np.float32(0.0))
    above_448 = np.nextafter(np.float32(448.0), np.float32(np.inf))
    above_896 = np.nextafter(np.float32(896.0), np.float32(np.inf))
    return np.array(
        [
            0.0,
            -0.0,
            smallest_subnormal,
            -smallest_subnormal,
            small_subnormal,
            -small_subnormal,
            below_448,
            -below_448,
            np.float32(448.0),
            -np.float32(448.0),
            above_448,
            -above_448,
            np.float32(896.0),
            above_896,
            np.inf,
            np.nan,
        ],
        dtype=np.float32,
    )
# Registry of boundary-pattern factories keyed by
# (destination format, source dtype key, scale algorithm).
# get_mx_boundary_pattern looks up the requested algorithm first and then
# falls back to the MX_SCALE_ALG_ANY entry.
MX_BOUNDARY_PATTERN_BUILDERS = {
    (MX_DST_MXFP8, MX_SRC_FP16, MX_SCALE_ALG_ANY): make_mxfp8_fp16_boundary_pattern,
    (MX_DST_MXFP8, MX_SRC_BF16, MX_SCALE_ALG_ANY): make_mxfp8_bf16_boundary_pattern,
    (MX_DST_MXFP8, MX_SRC_FP32, MX_SCALE_ALG_ANY): make_mxfp8_fp32_boundary_pattern,
}
def get_mx_boundary_pattern(dst_format, src_dtype, scale_alg=MX_SCALE_ALG_ANY):
    """Build the boundary pattern registered for a format / dtype / algorithm.

    The algorithm-specific registry entry is preferred; the
    ``MX_SCALE_ALG_ANY`` entry is used as fallback.

    Raises:
        ValueError: if neither entry exists (or the dtype is unsupported).
    """
    dtype_key = get_mx_src_dtype_key(src_dtype)

    specific_builder = MX_BOUNDARY_PATTERN_BUILDERS.get((dst_format, dtype_key, scale_alg))
    if specific_builder is not None:
        return specific_builder()

    generic_builder = MX_BOUNDARY_PATTERN_BUILDERS.get((dst_format, dtype_key, MX_SCALE_ALG_ANY))
    if generic_builder is not None:
        return generic_builder()

    raise ValueError(f"Unsupported MX boundary pattern: dst={dst_format}, src={dtype_key}, scale_alg={scale_alg}")
def fill_mx_boundary_groups(total, src_dtype, pattern, group_size=MX_BOUNDARY_GROUP_SIZE):
    """Return ``total`` values: a zero first group, then the repeated pattern.

    The first ``group_size`` elements stay zero.  Every later group
    (including a shorter trailing one) is filled with ``pattern`` resized to
    the group length and cast to ``src_dtype``.
    """
    values = np.zeros(total, dtype=src_dtype)
    # Starting at group_size skips the first group, which is left as zeros.
    for start in range(group_size, total, group_size):
        stop = min(start + group_size, total)
        values[start:stop] = np.resize(pattern, stop - start).astype(src_dtype)
    return values
def make_mx_boundary_values(valid_rows, valid_cols, src_dtype, dst_format, scale_alg=MX_SCALE_ALG_ANY):
    """Build a ``(valid_rows, valid_cols)`` array of boundary test values.

    ``src_dtype`` is normalised through its ``MX_SRC_*`` key, the matching
    pattern is fetched with ``get_mx_boundary_pattern`` and laid out by
    ``fill_mx_boundary_groups``.
    """
    canonical_dtype = get_mx_src_dtype(get_mx_src_dtype_key(src_dtype))
    pattern = get_mx_boundary_pattern(dst_format, canonical_dtype, scale_alg=scale_alg)
    flat_values = fill_mx_boundary_groups(valid_rows * valid_cols, canonical_dtype, pattern)
    return flat_values.reshape(valid_rows, valid_cols).astype(canonical_dtype)
def make_mxfp8_boundary_values(valid_rows, valid_cols, dtype, scale_alg=MX_SCALE_ALG_ANY):
    """Call ``make_mx_boundary_values`` with the MXFP8 destination format."""
    boundary_values = make_mx_boundary_values(valid_rows, valid_cols, dtype, MX_DST_MXFP8, scale_alg=scale_alg)
    return boundary_values
def fp16_to_mxfp8(valid_rows, valid_cols, mode, scale_alg="ocp", case_suffix=None):
    """Generate float16 input and MXFP8 golden files for one test case.

    Input selection: ``"boundary"`` uses the boundary pattern, a suffix
    containing ``"_exp2d_fuzz"`` uses the seeded fuzz generator, anything
    else draws log-normal magnitudes with random signs from NumPy's global
    RNG, clipped to +/-6e4.

    Writes ``input.bin`` plus the files produced by ``quant_fp16_to_e4m3``.
    When the columns had to be padded to a multiple of 32 in ``"nd"`` mode,
    ``golden_fp8.bin`` is rewritten with only the valid columns.
    """
    padded_cols = ((valid_cols + 31) // 32) * 32
    wants_fuzz = case_suffix is not None and "_exp2d_fuzz" in case_suffix

    if case_suffix == "boundary":
        src_fp16 = make_mxfp8_boundary_values(valid_rows, valid_cols, np.float16, scale_alg=scale_alg)
    elif wants_fuzz:
        fuzz_seed = get_exp2d_fuzz_seed(case_suffix)
        src_fp16 = make_mx_exp2d_fuzz_values(valid_rows, valid_cols, np.float16, fuzz_seed, max_abs=448.0)
    else:
        shape = (valid_rows, valid_cols)
        magnitudes = np.random.lognormal(mean=0.0, sigma=2.0, size=shape)
        signs = np.where(np.random.rand(valid_rows, valid_cols) < 0.5, -1.0, 1.0)
        random_fp32 = np.clip((magnitudes * signs).astype(np.float32), -6e4, 6e4)
        src_fp16 = random_fp32.astype(np.float16)
    src_fp16.tofile("input.bin")

    # Zero-pad the columns up to the next multiple of 32.
    padded_src = np.full((valid_rows, padded_cols), np.float16(0.0), dtype=np.float16)
    padded_src[:, :valid_cols] = src_fp16
    _, _, data_fp8 = quant_fp16_to_e4m3(padded_src, mode=mode, scale_alg=scale_alg)

    if mode == "nd" and padded_cols != valid_cols:
        valid_part = data_fp8.reshape(valid_rows, padded_cols)[:, :valid_cols].copy()
        valid_part.tofile("golden_fp8.bin")
def make_mxfp4_exp_random_values(total, seed):
    """Return ``total`` seeded float16 values spread over many exponents.

    Each value is ``sign * mantissa * 2**exponent`` with the exponent drawn
    from [-24, 16), the mantissa from [1, 2) and a random sign, clipped to
    +/-65504.  ``max(1, total // 32)`` randomly chosen positions are then
    set to zero.
    """
    rng = np.random.default_rng(seed)
    # The draw order (exponents, mantissas, signs, zero positions) determines
    # the generated data and must stay as is.
    exponents = rng.integers(-24, 16, size=total, dtype=np.int32)
    mantissas = rng.uniform(1.0, 2.0, size=total).astype(np.float32)
    is_negative = rng.random(total) < 0.5
    signs = np.where(is_negative, -1.0, 1.0).astype(np.float32)

    magnitudes = np.ldexp(mantissas, exponents).astype(np.float32)
    values = np.clip(magnitudes * signs, -65504.0, 65504.0)

    zero_positions = rng.choice(total, size=max(1, total // 32), replace=False)
    values[zero_positions] = np.float32(0.0)
    return values.astype(np.float16)
def make_mxfp4_exp_random_values_bf16(total, seed):
    """Return ``total`` seeded bfloat16 values spread over many exponents.

    Each value is ``sign * mantissa * 2**exponent`` with the exponent drawn
    from [-133, 128), the mantissa from [1, 2) and a random sign.
    ``max(1, total // 32)`` randomly chosen positions are then set to zero.
    """
    rng = np.random.default_rng(seed)
    # The draw order (exponents, mantissas, signs, zero positions) determines
    # the generated data and must stay as is.
    exponents = rng.integers(-133, 128, size=total, dtype=np.int32)
    mantissas = rng.uniform(1.0, 2.0, size=total).astype(np.float32)
    is_negative = rng.random(total) < 0.5
    signs = np.where(is_negative, -1.0, 1.0).astype(np.float32)

    values = np.ldexp(mantissas, exponents).astype(np.float32) * signs

    zero_positions = rng.choice(total, size=max(1, total // 32), replace=False)
    values[zero_positions] = np.float32(0.0)
    return values.astype(bfloat16)
def make_mxfp4_rounding_values(dtype):
    """Return the MXFP4 rounding test vector as an array of ``dtype``.

    It consists of +/- pairs for the magnitudes 4.0 down to 1.25, followed by
    ``MXFP4_LOW_MAGNITUDE_VALUES``.
    """
    upper_magnitudes = (4.0, 3.75, 3.5, 3.0, 2.5, 2.25, 2.0, 1.75, 1.5, 1.25)
    signed_values = []
    for magnitude in upper_magnitudes:
        signed_values += [magnitude, -magnitude]
    signed_values += list(MXFP4_LOW_MAGNITUDE_VALUES)
    return np.array(signed_values, dtype=dtype)
def next_up_for_dtype(value, dtype):
    """Return the next representable value above ``value`` for ``dtype``.

    float16 and the default (float32) path use ``np.nextafter``.  bfloat16
    increments the upper 16 bits of the float32 pattern by one and returns
    the result as float32.
    """
    if dtype == np.float16:
        successor = np.nextafter(np.float16(value), np.float16(np.inf))
        return successor.astype(np.float16)

    if dtype == bfloat16:
        fp32_bits = np.frombuffer(np.float32(value).tobytes(), dtype=np.uint32)[0]
        bf16_bits = np.uint16(fp32_bits >> np.uint32(16))
        bumped = np.array([bf16_bits + np.uint16(1)], dtype=np.uint16)
        return bf16_bits_to_float32(bumped)[0].astype(np.float32)

    return np.nextafter(np.float32(value), np.float32(np.inf))
def make_mxfp4_nv_boundary_values(total, dtype, patterns):
    """Return ``total`` boundary values for the "nv" MXFP4 scale algorithm.

    Sixteen group patterns are cycled, one per ``MX_BOUNDARY_GROUP_SIZE``
    elements; each pattern is resized to the group length.  ``patterns`` must
    provide ``"special_values"``, ``"subnormal_values"`` and
    ``"exp_random_func"``.
    """
    above_1p5 = next_up_for_dtype(1.5, dtype)
    above_3 = next_up_for_dtype(3.0, dtype)
    above_6 = next_up_for_dtype(6.0, dtype)
    special_values = patterns["special_values"]
    subnormal_values = patterns["subnormal_values"]
    random_values = patterns["exp_random_func"](32, seed=20260513)

    group_patterns = [
        [0.0, -0.0],
        subnormal_values,
        [1.5, -1.5, 1.0, -1.0, 0.75, -0.75],
        [above_1p5, -above_1p5, 1.5, -1.5, 1.0, -1.0],
        [3.0, -3.0, 2.5, -2.5, 1.5, -1.5],
        [above_3, -above_3, 3.0, -3.0, 1.5, -1.5],
        [6.0, -6.0, 4.0, -4.0, 3.0, -3.0],
        [above_6, -above_6, 6.0, -6.0, 3.0, -3.0],
        [-6.0, -4.0, -3.0, -1.5, -0.75, -0.0],
        make_mxfp4_rounding_values(dtype),
        [np.inf, 6.0, -6.0, 3.0, -3.0, 0.0],
        [-np.inf, 6.0, -6.0, 3.0, -3.0, -0.0],
        [np.nan, 6.0, -6.0, 3.0, -3.0, 0.0],
        special_values,
        random_values,
        [3.75, -3.75, 2.25, -2.25, 1.25, -1.25, 0.375, -0.375],
    ]
    pattern_count = len(group_patterns)

    values = np.zeros(total, dtype=dtype)
    for group, begin in enumerate(range(0, total, MX_BOUNDARY_GROUP_SIZE)):
        end = min(begin + MX_BOUNDARY_GROUP_SIZE, total)
        chosen = np.asarray(group_patterns[group % pattern_count], dtype=dtype)
        values[begin:end] = np.resize(chosen, end - begin)
    return values
def get_exp2d_fuzz_seed(case_suffix):
    """Return 20260600 plus the integer that follows the last ``"fuzz"``."""
    base_seed = 20260600
    fuzz_index_text = case_suffix.rsplit("fuzz", 1)[1]
    return base_seed + int(fuzz_index_text)
def make_mxfp4_exp2d_base_pattern():
    """Return the 32-element float32 base pattern for the exp2d generators.

    Layout: +/- pairs of 6, 4, 3, 2.5, 2 and 1.5, then
    ``MXFP4_LOW_MAGNITUDE_VALUES``, then +0 / -0, then +/- pairs of 3.75,
    2.25 and 1.25.
    """
    def signed_pairs(magnitudes):
        # Expand each magnitude m into the two entries m, -m.
        pairs = []
        for magnitude in magnitudes:
            pairs += [magnitude, -magnitude]
        return pairs

    entries = signed_pairs((6.0, 4.0, 3.0, 2.5, 2.0, 1.5))
    entries += list(MXFP4_LOW_MAGNITUDE_VALUES)
    entries += [0.0, -0.0]
    entries += signed_pairs((3.75, 2.25, 1.25))
    return np.array(entries, dtype=np.float32)
def make_mx_exp2d_fuzz_group_pattern(group, base_pattern, rng, max_abs):
    """Return the base pattern scaled and perturbed for one group index.

    The pattern is multiplied by ``2**((group % 17) - 8)``.  Every 5th group
    gets its first eight entries replaced by zeros, +/-max_abs, +/-max_abs/2
    and +/-1; every 7th group gets entries 16..23 replaced by eight
    log-normal magnitudes with random signs drawn from ``rng``.
    ``base_pattern`` itself is not modified.
    """
    exponent = int((group % 17) - 8)
    pattern = base_pattern * np.ldexp(np.float32(1.0), exponent)

    if group % 5 == 0:
        half_max = max_abs / 2.0
        extremes = [0.0, -0.0, max_abs, -max_abs, half_max, -half_max, 1.0, -1.0]
        pattern[:8] = np.array(extremes, dtype=np.float32)

    if group % 7 == 0:
        magnitudes = rng.lognormal(mean=0.0, sigma=2.0, size=8).astype(np.float32)
        is_negative = rng.random(8) < 0.5
        pattern[16:24] = magnitudes * np.where(is_negative, -1.0, 1.0).astype(np.float32)

    return pattern
def make_mx_exp2d_fuzz_values(valid_rows, valid_cols, dtype, seed, max_abs):
    """Build a seeded ``(valid_rows, valid_cols)`` fuzz array of ``dtype``.

    Each complete ``MX_BOUNDARY_GROUP_SIZE``-wide column group receives the
    pattern from ``make_mx_exp2d_fuzz_group_pattern``; groups are numbered
    row by row.  Columns beyond the last complete group stay zero.  The
    values are clipped to +/-60000 for float16 and to +/-``max_abs`` otherwise.
    """
    rng = np.random.default_rng(seed)
    groups_per_row = valid_cols // MX_BOUNDARY_GROUP_SIZE
    base_pattern = make_mxfp4_exp2d_base_pattern()
    values = np.zeros((valid_rows, valid_cols), dtype=np.float32)

    # Row-major group order keeps the rng consumption order unchanged.
    for group in range(valid_rows * groups_per_row):
        row, col_group = divmod(group, groups_per_row)
        first_col = col_group * MX_BOUNDARY_GROUP_SIZE
        last_col = first_col + MX_BOUNDARY_GROUP_SIZE
        values[row, first_col:last_col] = make_mx_exp2d_fuzz_group_pattern(group, base_pattern, rng, max_abs)

    if dtype == np.float16:
        clip_abs = 60000.0
    else:
        clip_abs = max_abs
    return np.clip(values, -clip_abs, clip_abs).astype(dtype)
def make_mxfp4_static4x128_exp2d_values(dtype):
    """Return a ``(2, 128)`` array of eight scaled copies of the base pattern.

    Group ``g`` occupies row ``g // 4`` and the 32 columns starting at
    ``(g % 4) * 32``; it holds the exp2d base pattern scaled so that its
    largest magnitude is 6, 12, 24, ..., 768 respectively.
    """
    group_maxes = np.array([6.0, 12.0, 24.0, 48.0, 96.0, 192.0, 384.0, 768.0], dtype=np.float32)
    base_pattern = make_mxfp4_exp2d_base_pattern()
    values = np.zeros((2, 128), dtype=np.float32)
    for group, group_max in enumerate(group_maxes):
        row, col_group = divmod(group, 4)
        first_col = col_group * MX_BOUNDARY_GROUP_SIZE
        last_col = first_col + MX_BOUNDARY_GROUP_SIZE
        values[row, first_col:last_col] = base_pattern * (group_max / 6.0)
    return values.astype(dtype)
def make_mxfp4_e2m1_data(config, patterns):
    """Generate the MXFP4 input array selected by ``config.case_suffix``.

    Args:
        config: ``MxFp4DataConfig`` with shape, dtype, case suffix and scale
            algorithm.
        patterns: dict providing ``"special_values"``, ``"subnormal_values"``
            and ``"exp_random_func"`` (called as ``func(total, seed=...)``).

    Returns:
        Array of shape ``(config.valid_rows, config.valid_cols)`` and dtype
        ``config.dtype``.  Unknown suffixes (and ``None``) fall back to
        exponent-random data with seed 20260511.
    """
    total = config.valid_rows * config.valid_cols
    dtype = config.dtype
    case_suffix = config.case_suffix
    special_values = patterns["special_values"]
    subnormal_values = patterns["subnormal_values"]
    rounding_values = make_mxfp4_rounding_values(dtype)
    exp_random_func = patterns["exp_random_func"]

    def build_flat_values():
        # One early return per named case; the last line is the default.
        if case_suffix == "special":
            return np.resize(special_values, total)
        if case_suffix == "inf_only":
            return np.resize(np.array([np.inf, -np.inf], dtype=dtype), total)
        if case_suffix == "subnormal":
            return np.resize(subnormal_values, total)
        if case_suffix == "rounding":
            return np.resize(rounding_values, total)
        if case_suffix == "boundary":
            if config.scale_alg == MX_SCALE_ALG_NV:
                return make_mxfp4_nv_boundary_values(total, dtype, patterns)
            return np.resize(special_values, total)
        if case_suffix == "exp_random_a":
            return exp_random_func(total, seed=20260508)
        if case_suffix == "exp_random_b":
            return exp_random_func(total, seed=20260509)
        if case_suffix == "normal":
            rng = np.random.default_rng(20260512)
            return rng.uniform(-6.0, 6.0, size=total).astype(dtype)
        if case_suffix == "mixed":
            mixed = np.zeros(total, dtype=dtype)
            cycle = [special_values, subnormal_values, rounding_values, exp_random_func(32, seed=20260510)]
            for group, begin in enumerate(range(0, total, 32)):
                end = min(begin + 32, total)
                mixed[begin:end] = np.resize(cycle[group % len(cycle)], end - begin)
            return mixed
        if case_suffix == "static4x128_exp2d":
            return make_mxfp4_static4x128_exp2d_values(dtype).reshape(-1)
        if case_suffix is not None and "_exp2d_fuzz" in case_suffix:
            fuzz_seed = get_exp2d_fuzz_seed(case_suffix)
            fuzz = make_mx_exp2d_fuzz_values(config.valid_rows, config.valid_cols, dtype, fuzz_seed, max_abs=768.0)
            return fuzz.reshape(-1)
        return exp_random_func(total, seed=20260511)

    flat_values = build_flat_values()
    return flat_values.reshape(config.valid_rows, config.valid_cols).astype(dtype)
def make_mxfp4_e2m1_fp16_data(valid_rows, valid_cols, case_suffix, scale_alg=MX_SCALE_ALG_OCP):
    """Generate float16 MXFP4 input data for the given case.

    Supplies the float16 special values, the float16 subnormal / small-normal
    bit patterns and ``make_mxfp4_exp_random_values`` to
    ``make_mxfp4_e2m1_data``.
    """
    special_values = np.array(
        [0.0, -0.0, np.inf, -np.inf, np.nan, 65504.0, -65504.0, 6.0, -6.0, 4.0, -4.0, 1.5, -1.5, 0.5, -0.5, 0.25],
        dtype=np.float16,
    )

    # Each magnitude pattern is followed by the same pattern with the sign bit set.
    magnitude_bits = (0x0001, 0x0002, 0x0003, 0x0010, 0x0100, 0x0200, 0x03FF, 0x0400)
    signed_bits = []
    for bits in magnitude_bits:
        signed_bits += [bits, 0x8000 | bits]
    subnormal_bits = np.array(signed_bits, dtype=np.uint16)

    config = MxFp4DataConfig(
        valid_rows=valid_rows,
        valid_cols=valid_cols,
        case_suffix=case_suffix,
        dtype=np.float16,
        scale_alg=scale_alg,
    )
    patterns = {
        "special_values": special_values,
        "subnormal_values": subnormal_bits.view(np.float16),
        "exp_random_func": make_mxfp4_exp_random_values,
    }
    return make_mxfp4_e2m1_data(config, patterns)
def make_mxfp4_e2m1_bf16_data(valid_rows, valid_cols, case_suffix, scale_alg=MX_SCALE_ALG_OCP):
    """Generate bfloat16 MXFP4 input data for the given case.

    Supplies the bfloat16 special values, the bfloat16 subnormal /
    small-normal bit patterns and ``make_mxfp4_exp_random_values_bf16`` to
    ``make_mxfp4_e2m1_data``.
    """
    special_list = [0.0, -0.0, np.inf, -np.inf, np.nan, 3.38953139e38, -3.38953139e38]
    special_list += [6.0, -6.0, 4.0, -4.0, 1.5, -1.5, 0.5, -0.5, 0.25]
    special_values = np.array(special_list, dtype=bfloat16)

    # Each magnitude pattern is followed by the same pattern with the sign bit set.
    magnitude_bits = (0x0001, 0x0002, 0x0003, 0x0010, 0x0040, 0x007F, 0x0080, 0x0100)
    signed_bits = []
    for bits in magnitude_bits:
        signed_bits += [bits, 0x8000 | bits]
    subnormal_bits = np.array(signed_bits, dtype=np.uint16)

    config = MxFp4DataConfig(
        valid_rows=valid_rows,
        valid_cols=valid_cols,
        case_suffix=case_suffix,
        dtype=bfloat16,
        scale_alg=scale_alg,
    )
    patterns = {
        "special_values": special_values,
        "subnormal_values": subnormal_bits.view(bfloat16),
        "exp_random_func": make_mxfp4_exp_random_values_bf16,
    }
    return make_mxfp4_e2m1_data(config, patterns)
def fp16_to_mxfp4_e2m1(valid_rows, valid_cols, mode, case_suffix=None, scale_alg="ocp"):
    """Generate float16 input and MXFP4 golden files for one test case.

    Writes ``input.bin`` plus the files produced by ``quant_fp16_to_e2m1``.
    When the columns had to be padded to a multiple of 32 in ``"nd"`` mode,
    ``golden_fp4.bin`` is rewritten with only the bytes of the valid columns.
    """
    src_fp16 = make_mxfp4_e2m1_fp16_data(valid_rows, valid_cols, case_suffix, scale_alg=scale_alg)
    src_fp16.tofile("input.bin")

    # Zero-pad the columns up to the next multiple of 32.
    padded_cols = ((valid_cols + 31) // 32) * 32
    padded_src = np.zeros((valid_rows, padded_cols), dtype=np.float16)
    padded_src[:, :valid_cols] = src_fp16
    _, _, packed = quant_fp16_to_e2m1(padded_src, scale_alg=scale_alg)

    if mode == "nd" and padded_cols != valid_cols:
        valid_bytes = (valid_cols + 1) // 2
        packed[:, :valid_bytes].copy().tofile("golden_fp4.bin")
def quant_bf16_to_e2m1(src, scale_alg="ocp"):
    """Quantize bfloat16 ``src`` to packed 4-bit codes with per-32-element scales.

    Group maxima are taken from the bfloat16 absolute values.  Scales come
    from ``nv_maxes_to_e2m1`` when ``scale_alg == "nv"`` and from
    ``fp16_maxes_to_e2m1`` otherwise; they are cast to bfloat16 and the
    product is rounded to bfloat16 before encoding.

    Writes ``golden_e8m0.bin``, ``scaling_e2m1.bin`` (float32) and
    ``golden_fp4.bin`` and returns ``(e8m0, scaling_bf16, packed)``.
    """
    magnitudes = np.abs(src).astype(bfloat16).reshape(-1, 32)
    with np.errstate(invalid="ignore"):
        group_max_bf16 = np.max(magnitudes, axis=1)

    derive_scales = nv_maxes_to_e2m1 if scale_alg == "nv" else fp16_maxes_to_e2m1
    e8m0, scaling_fp32 = derive_scales(group_max_bf16.astype(np.float32))
    scaling_bf16 = scaling_fp32.astype(bfloat16)

    scaled = (src.reshape(-1, 32) * scaling_bf16).astype(bfloat16).reshape(src.shape)
    encode = np.vectorize(encode_e2m1_magic_scalar, otypes=[np.uint8])
    packed = pack_fp4_e2m1(encode(scaled.astype(np.float32)), src.shape[0], src.shape[1])

    e8m0.tofile("golden_e8m0.bin")
    scaling_bf16.astype(np.float32).tofile("scaling_e2m1.bin")
    packed.tofile("golden_fp4.bin")
    return e8m0, scaling_bf16, packed
def bf16_to_mxfp4_e2m1(valid_rows, valid_cols, mode, case_suffix=None, scale_alg="ocp"):
    """Generate bfloat16 input and MXFP4 golden files for one test case.

    Writes ``input.bin`` plus the files produced by ``quant_bf16_to_e2m1``.
    When the columns had to be padded to a multiple of 32 in ``"nd"`` mode,
    ``golden_fp4.bin`` is rewritten with only the bytes of the valid columns.
    """
    src_bf16 = make_mxfp4_e2m1_bf16_data(valid_rows, valid_cols, case_suffix, scale_alg=scale_alg)
    src_bf16.tofile("input.bin")

    # Zero-pad the columns up to the next multiple of 32.
    padded_cols = ((valid_cols + 31) // 32) * 32
    padded_src = np.zeros((valid_rows, padded_cols), dtype=bfloat16)
    padded_src[:, :valid_cols] = src_bf16
    _, _, packed = quant_bf16_to_e2m1(padded_src, scale_alg=scale_alg)

    if mode == "nd" and padded_cols != valid_cols:
        valid_bytes = (valid_cols + 1) // 2
        packed[:, :valid_bytes].copy().tofile("golden_fp4.bin")
def bf16_to_mxfp8(valid_rows, valid_cols, mode, scale_alg="ocp", case_suffix=None):
    """Generate bfloat16 input and MXFP8 golden files for one test case.

    Input selection: ``"boundary"`` uses the boundary pattern, a suffix
    containing ``"_exp2d_fuzz"`` uses the seeded fuzz generator, anything
    else draws log-normal magnitudes with random signs from NumPy's global
    RNG, clipped to +/-1e4.

    Writes ``input.bin`` plus the files produced by ``quant_bf16_to_e4m3``.
    When the columns had to be padded to a multiple of 32 in ``"nd"`` mode,
    ``golden_fp8.bin`` is rewritten with only the valid columns.
    """
    padded_cols = ((valid_cols + 31) // 32) * 32
    wants_fuzz = case_suffix is not None and "_exp2d_fuzz" in case_suffix

    if case_suffix == "boundary":
        src_bf16 = make_mxfp8_boundary_values(valid_rows, valid_cols, bfloat16, scale_alg=scale_alg)
    elif wants_fuzz:
        fuzz_seed = get_exp2d_fuzz_seed(case_suffix)
        src_bf16 = make_mx_exp2d_fuzz_values(valid_rows, valid_cols, bfloat16, fuzz_seed, max_abs=448.0)
    else:
        shape = (valid_rows, valid_cols)
        magnitudes = np.random.lognormal(mean=0.0, sigma=2.0, size=shape)
        signs = np.where(np.random.rand(valid_rows, valid_cols) < 0.5, -1.0, 1.0)
        random_fp32 = np.clip((magnitudes * signs).astype(np.float32), -1e4, 1e4)
        src_bf16 = random_fp32.astype(bfloat16)
    src_bf16.tofile("input.bin")

    # Zero-pad the columns up to the next multiple of 32.
    padded_src = np.full((valid_rows, padded_cols), bfloat16(0.0), dtype=bfloat16)
    padded_src[:, :valid_cols] = src_bf16
    _e8m0, _scaling, data_fp8, _group_max = quant_bf16_to_e4m3(padded_src, mode=mode, scale_alg=scale_alg)

    if mode == "nd" and padded_cols != valid_cols:
        valid_part = data_fp8.reshape(valid_rows, padded_cols)[:, :valid_cols].copy()
        valid_part.tofile("golden_fp8.bin")
def fp32_to_int8_sym(valid_rows, valid_cols, mode):
    """Generate random float32 input and its symmetric per-row int8 golden data.

    The input is uniform in [-2, 2) from NumPy's global RNG.  Each row is
    scaled by ``127 / max(|row|)`` (0 when the row maximum is 0), rounded,
    passed through float16 and clipped to [-128, 127].  ``mode`` is unused.

    Writes ``input.bin``, ``inv_scale_fp32.bin``, ``offset_fp32.bin`` (all
    zeros) and ``golden_s8.bin``; returns ``(src_fp32, src_s8)``.
    """
    src_fp32 = np.random.uniform(low=-2, high=2, size=(valid_rows, valid_cols)).astype(np.float32)
    src_fp32.tofile("input.bin")

    row_abs_max = np.max(np.abs(src_fp32), axis=1, keepdims=True)
    scale = (row_abs_max / 127.0).astype(np.float32)
    inv_scale = np.where(scale != 0, 1.0 / scale, 0.0).astype(np.float32)
    inv_scale.tofile("inv_scale_fp32.bin")

    # Symmetric quantization has no offset; an all-zero column is still written.
    np.zeros((valid_rows, 1), dtype=np.float32).tofile("offset_fp32.bin")

    rounded = np.round(src_fp32 * inv_scale).astype(np.float32)
    rounded_fp16 = rounded.astype(np.float16)
    src_s8 = np.clip(rounded_fp16, -128, 127).astype(np.int8)
    src_s8.tofile("golden_s8.bin")
    return src_fp32, src_s8
def fp32_to_int8_asym(valid_rows, valid_cols, mode):
    """Generate random float32 input and its asymmetric per-row uint8 golden data.

    The input is uniform in [-2, 2) from NumPy's global RNG.  For each row,
    ``scale = (max - min) / 255``, ``inv_scale = 1 / scale`` and
    ``zero_point = clip(round(-min / scale), 0, 255)``.  The quantized value is
    ``clip(round(x * inv_scale + zero_point), 0, 255)`` after a pass through
    float16.  ``mode`` is unused.

    Rows with zero range (``scale == 0``) get ``inv_scale = 0`` and
    ``zero_point = 0``.  The zero-point division was previously unguarded,
    which produced +/-inf or NaN offsets for such rows.

    Writes ``input.bin``, ``inv_scale_fp32.bin``, ``offset_fp32.bin`` and
    ``golden_u8.bin``; returns ``(src_fp32, src_u8)``.
    """
    src_fp32 = np.random.uniform(low=-2, high=2, size=(valid_rows, valid_cols)).astype(np.float32)
    src_fp32.tofile("input.bin")

    row_min = np.min(src_fp32, axis=1, keepdims=True)
    row_max = np.max(src_fp32, axis=1, keepdims=True)
    scale = ((row_max - row_min) / 255.0).astype(np.float32)
    has_range = scale != 0

    # Both divisions are evaluated only where scale is non-zero; the
    # pre-zeroed output supplies 0 for the remaining rows.
    inv_scale = np.divide(np.float32(1.0), scale, out=np.zeros_like(scale), where=has_range)
    inv_scale.tofile("inv_scale_fp32.bin")

    raw_zero_point = np.divide(-row_min, scale, out=np.zeros_like(scale), where=has_range)
    zero_point = np.clip(np.round(raw_zero_point), 0, 255).astype(np.float32)
    zero_point.tofile("offset_fp32.bin")

    shifted = src_fp32 * inv_scale + zero_point
    rounded = np.round(shifted).astype(np.float32)
    rounded_fp16 = rounded.astype(np.float16)
    src_u8 = np.clip(rounded_fp16, 0, 255).astype(np.uint8)
    src_u8.tofile("golden_u8.bin")
    return src_fp32, src_u8
def fp32_to_mxfp8(valid_rows, valid_cols, mode, scale_alg="ocp", case_suffix=None):
    """Generate float32 input and MXFP8 golden files for one test case.

    Input selection: ``"boundary"`` uses the boundary pattern, a suffix
    containing ``"_exp2d_fuzz"`` uses the seeded fuzz generator, anything
    else draws log-normal magnitudes with random signs from NumPy's global
    RNG, clipped to +/-1e8.

    Writes ``input.bin`` plus the files produced by ``quant_fp32_to_e4m3``.
    When the columns had to be padded to a multiple of 32 in ``"nd"`` mode,
    ``golden_fp8.bin`` is rewritten with only the valid columns.
    """
    padded_cols = ((valid_cols + 31) // 32) * 32
    wants_fuzz = case_suffix is not None and "_exp2d_fuzz" in case_suffix

    if case_suffix == "boundary":
        src_fp32 = make_mxfp8_boundary_values(valid_rows, valid_cols, np.float32, scale_alg=scale_alg)
    elif wants_fuzz:
        fuzz_seed = get_exp2d_fuzz_seed(case_suffix)
        src_fp32 = make_mx_exp2d_fuzz_values(valid_rows, valid_cols, np.float32, fuzz_seed, max_abs=448.0)
    else:
        shape = (valid_rows, valid_cols)
        magnitudes = np.random.lognormal(mean=0.0, sigma=2.0, size=shape)
        signs = np.where(np.random.rand(valid_rows, valid_cols) < 0.5, -1.0, 1.0)
        src_fp32 = np.clip((magnitudes * signs).astype(np.float32), -1e8, 1e8)
    src_fp32.tofile("input.bin")

    # Zero-pad the columns up to the next multiple of 32.
    padded_src = np.full((valid_rows, padded_cols), np.float32(0.0), dtype=np.float32)
    padded_src[:, :valid_cols] = src_fp32
    _e8m0, _scaling, data_fp8, _group_max = quant_fp32_to_e4m3(padded_src, mode=mode, scale_alg=scale_alg)

    if mode == "nd" and padded_cols != valid_cols:
        valid_part = data_fp8.reshape(valid_rows, padded_cols)[:, :valid_cols].copy()
        valid_part.tofile("golden_fp8.bin")
def gen_golden_data_tquant(case_name, param):
    """Dispatch one test case to the matching golden-data generator.

    ``param.out_dtype_str`` selects the int8 or MXFP4 generators; any other
    value is handled as MXFP8, with the source dtype picking the bf16, fp16
    or fp32 variant.  ``case_name`` is not used here.
    """
    dtype = param.dtype
    rows, cols = param.valid_rows, param.valid_cols
    mode = param.mode
    target = param.out_dtype_str

    if target == "int8_sym":
        fp32_to_int8_sym(rows, cols, mode)
        return
    if target == "int8_asym":
        fp32_to_int8_asym(rows, cols, mode)
        return
    if target == "mxfp4_e2m1":
        mxfp4_generator = bf16_to_mxfp4_e2m1 if dtype == bfloat16 else fp16_to_mxfp4_e2m1
        mxfp4_generator(rows, cols, mode, case_suffix=param.case_suffix, scale_alg=param.scale_alg)
        return

    # Everything else is an MXFP8 case; choose the generator by source dtype.
    if dtype == bfloat16:
        mxfp8_generator = bf16_to_mxfp8
    elif dtype == np.float16:
        mxfp8_generator = fp16_to_mxfp8
    else:
        mxfp8_generator = fp32_to_mxfp8
    mxfp8_generator(rows, cols, mode, scale_alg=param.scale_alg, case_suffix=param.case_suffix)
class TQuantParams:
    """Parameter record describing a single TQuant golden-data case.

    Attributes:
        valid_rows: Row count, stored as given.
        valid_cols: Column count, stored as given.
        dtype: Source dtype object (``np.float32``, ``np.float16`` or ``bfloat16``).
        mode: Mode tag, stored as given (defaults to ``"nd"``).
        case_suffix: Optional extra tag for the case, or ``None``.
        scale_alg: Scale algorithm tag, stored as given (defaults to ``"ocp"``).
        out_dtype_str: Canonical output name derived from the short tag:
            ``"s8"`` -> ``"int8_sym"``, ``"u8"`` -> ``"int8_asym"``;
            ``"mxfp8"`` and ``"mxfp4_e2m1"`` map to themselves.
        dtype_str: Short source dtype name: ``"fp32"``, ``"bf16"`` or ``"fp16"``.

    Raises:
        KeyError: If ``out_dtype_str`` or ``dtype`` is not one of the
            supported values listed above.
    """

    def __init__(
        self, out_dtype_str, valid_rows, valid_cols, mode="nd", dtype=np.float32, case_suffix=None, scale_alg="ocp"
    ):
        # Plain settings are stored verbatim.
        self.valid_rows, self.valid_cols = valid_rows, valid_cols
        self.dtype, self.mode = dtype, mode
        self.case_suffix, self.scale_alg = case_suffix, scale_alg

        # Short output tag -> canonical output name (unknown tags raise KeyError).
        out_dtype_names = {
            "s8": "int8_sym",
            "mxfp8": "mxfp8",
            "mxfp4_e2m1": "mxfp4_e2m1",
            "u8": "int8_asym",
        }
        self.out_dtype_str = out_dtype_names[out_dtype_str]

        # Source dtype object -> short name (unknown dtypes raise KeyError).
        src_dtype_names = {np.float32: "fp32", bfloat16: "bf16", np.float16: "fp16"}
        self.dtype_str = src_dtype_names[dtype]
def generate_case_name(param):
    """Build the test-case name string for ``param``.

    The name has the form
    ``TQUANTTEST.case_<out>[_nv]_<dtype>_<rows>x<cols>[_<suffix>]_<mode>``.
    The ``_nv`` part appears only when ``param.scale_alg == "nv"``, and the
    suffix part appears whenever ``param.case_suffix`` is not ``None``.

    Args:
        param: Object exposing ``case_suffix``, ``scale_alg``,
            ``out_dtype_str``, ``dtype_str``, ``valid_rows``, ``valid_cols``
            and ``mode``.

    Returns:
        The case name as a string.
    """
    tag = param.case_suffix
    if tag is None:
        suffix_part = ""
    else:
        # An empty-string suffix still contributes its leading underscore.
        suffix_part = f"_{tag}"

    alg_part = ""
    if param.scale_alg == "nv":
        alg_part = "_nv"

    return "TQUANTTEST.case_{}{}_{}_{}x{}{}_{}".format(
        param.out_dtype_str,
        alg_part,
        param.dtype_str,
        param.valid_rows,
        param.valid_cols,
        suffix_part,
        param.mode,
    )
def make_exp2d_fuzz_params():
    """Return the list of exp2d fuzz cases (currently a single bf16 MXFP8 case).

    Returns:
        A new list holding one ``TQuantParams`` for a 1x64 ``"mxfp8"`` case in
        ``"nd"`` mode with ``bfloat16`` input and the case suffix
        ``"static3x128_exp2d_fuzz01"``.
    """
    fuzz_case = TQuantParams(
        "mxfp8",
        1,
        64,
        mode="nd",
        dtype=bfloat16,
        case_suffix="static3x128_exp2d_fuzz01",
    )
    return [fuzz_case]
if __name__ == "__main__":
    # Make sure a "testcases" directory exists next to this script.
    # exist_ok=True replaces the separate exists() check, so the call cannot
    # fail when the directory appears between the check and the creation.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    testcases_dir = os.path.join(script_dir, "testcases")
    os.makedirs(testcases_dir, exist_ok=True)

    # NOTE: keep the order of this list stable. NumPy's global RNG is seeded
    # once at import time and the fp32 MXFP8 generator draws from it, so
    # reordering cases changes the generated inputs.
    case_params_list = [
        # MXFP8, fp32 input (default dtype), "nd" mode.
        TQuantParams("mxfp8", 32, 32, mode="nd"),
        TQuantParams("mxfp8", 32, 64, mode="nd"),
        TQuantParams("mxfp8", 64, 128, mode="nd"),
        TQuantParams("mxfp8", 128, 128, mode="nd"),
        TQuantParams("mxfp8", 15, 32, mode="nd"),
        TQuantParams("mxfp8", 7, 64, mode="nd"),
        TQuantParams("mxfp8", 33, 64, mode="nd"),
        TQuantParams("mxfp8", 13, 192, mode="nd"),
        TQuantParams("mxfp8", 32, 128, mode="nd", scale_alg="nv"),
        TQuantParams("mxfp8", 2, 256, mode="nd", case_suffix="boundary", scale_alg="nv"),
        # MXFP8, fp32 input, "nz" mode.
        TQuantParams("mxfp8", 32, 64, mode="nz"),
        TQuantParams("mxfp8", 64, 128, mode="nz"),
        TQuantParams("mxfp8", 64, 256, mode="nz"),
        TQuantParams("mxfp8", 64, 512, mode="nz"),
        TQuantParams("mxfp8", 128, 128, mode="nz"),
        # int8 symmetric ("s8") and asymmetric ("u8").
        TQuantParams("s8", 64, 128, mode="nd"),
        TQuantParams("s8", 128, 128, mode="nd"),
        TQuantParams("s8", 256, 128, mode="nd"),
        TQuantParams("u8", 64, 128, mode="nd"),
        TQuantParams("u8", 128, 128, mode="nd"),
        TQuantParams("u8", 256, 128, mode="nd"),
        TQuantParams("u8", 32, 72, mode="nd"),
        # MXFP8, bf16 input, "nd" mode.
        TQuantParams("mxfp8", 32, 128, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 64, 128, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 128, 128, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 14, 16, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 7, 48, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 18, 136, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 1, 32, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 2, 16, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 4, 16, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 8, 16, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 16, 16, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 3, 32, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 5, 96, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 1, 16, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 4, 256, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 4, 512, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 3, 256, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 5, 256, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 18, 138, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 1, 192, mode="nd", dtype=bfloat16),
        TQuantParams("mxfp8", 1, 198, mode="nd", dtype=bfloat16),
        # MXFP8, bf16 input, "nz" mode.
        TQuantParams("mxfp8", 32, 128, mode="nz", dtype=bfloat16),
        TQuantParams("mxfp8", 64, 128, mode="nz", dtype=bfloat16),
        TQuantParams("mxfp8", 128, 128, mode="nz", dtype=bfloat16),
        # MXFP8, bf16 input, "nv" scale algorithm.
        TQuantParams("mxfp8", 32, 128, mode="nd", dtype=bfloat16, scale_alg="nv"),
        TQuantParams("mxfp8", 64, 128, mode="nd", dtype=bfloat16, scale_alg="nv"),
        TQuantParams("mxfp8", 128, 128, mode="nd", dtype=bfloat16, scale_alg="nv"),
        TQuantParams("mxfp8", 7, 48, mode="nd", dtype=bfloat16, scale_alg="nv"),
        TQuantParams("mxfp8", 2, 256, mode="nd", dtype=bfloat16, case_suffix="boundary", scale_alg="nv"),
        # MXFP8, fp16 input, "nd" mode.
        TQuantParams("mxfp8", 32, 128, mode="nd", dtype=np.float16),
        TQuantParams("mxfp8", 64, 128, mode="nd", dtype=np.float16),
        TQuantParams("mxfp8", 128, 128, mode="nd", dtype=np.float16),
        TQuantParams("mxfp8", 4, 256, mode="nd", dtype=np.float16),
        TQuantParams("mxfp8", 11, 640, mode="nd", dtype=np.float16),
        TQuantParams("mxfp8", 32, 128, mode="nd", dtype=np.float16, scale_alg="nv"),
        TQuantParams("mxfp8", 64, 128, mode="nd", dtype=np.float16, scale_alg="nv"),
        TQuantParams("mxfp8", 128, 128, mode="nd", dtype=np.float16, scale_alg="nv"),
        TQuantParams("mxfp8", 2, 256, mode="nd", dtype=np.float16, case_suffix="boundary", scale_alg="nv"),
        # Cases returned by make_exp2d_fuzz_params().
        *make_exp2d_fuzz_params(),
        # MXFP4 (e2m1), fp16 input.
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=np.float16, case_suffix="special"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=np.float16, case_suffix="inf_only"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=np.float16, case_suffix="subnormal"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=np.float16, case_suffix="rounding"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=np.float16, case_suffix="exp_random_a"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=np.float16, case_suffix="exp_random_b"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=np.float16, case_suffix="mixed"),
        TQuantParams("mxfp4_e2m1", 32, 1024, mode="nd", dtype=np.float16, case_suffix="mixed"),
        TQuantParams("mxfp4_e2m1", 32, 1024, mode="nd", dtype=np.float16, case_suffix="normal"),
        TQuantParams("mxfp4_e2m1", 2, 256, mode="nd", dtype=np.float16, case_suffix="boundary", scale_alg="nv"),
        TQuantParams("mxfp4_e2m1", 2, 256, mode="nd", dtype=np.float16, case_suffix="rounding", scale_alg="nv"),
        TQuantParams("mxfp4_e2m1", 2, 256, mode="nd", dtype=np.float16, case_suffix="mixed", scale_alg="nv"),
        # MXFP4 (e2m1), bf16 input.
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="special"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="inf_only"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="subnormal"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="rounding"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="exp_random_a"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="exp_random_b"),
        TQuantParams("mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="mixed"),
        TQuantParams("mxfp4_e2m1", 32, 1024, mode="nd", dtype=bfloat16, case_suffix="mixed"),
        TQuantParams("mxfp4_e2m1", 32, 1024, mode="nd", dtype=bfloat16, case_suffix="normal"),
        TQuantParams("mxfp4_e2m1", 2, 256, mode="nd", dtype=bfloat16, case_suffix="boundary", scale_alg="nv"),
        TQuantParams("mxfp4_e2m1", 2, 256, mode="nd", dtype=bfloat16, case_suffix="rounding", scale_alg="nv"),
        TQuantParams("mxfp4_e2m1", 2, 256, mode="nd", dtype=bfloat16, case_suffix="mixed", scale_alg="nv"),
        TQuantParams(
            "mxfp4_e2m1", 2, 128, mode="nd", dtype=bfloat16, case_suffix="static4x128_exp2d", scale_alg="nv"
        ),
        # MXFP8, fp16 input, "nz" mode.
        TQuantParams("mxfp8", 32, 128, mode="nz", dtype=np.float16),
        TQuantParams("mxfp8", 64, 128, mode="nz", dtype=np.float16),
        TQuantParams("mxfp8", 128, 128, mode="nz", dtype=np.float16),
    ]

    for param in case_params_list:
        case_name = generate_case_name(param)
        # Case directories are created relative to the current working
        # directory, one per case name.
        os.makedirs(case_name, exist_ok=True)
        original_dir = os.getcwd()
        # Work inside the case directory while its data is generated.
        os.chdir(case_name)
        try:
            gen_golden_data_tquant(case_name, param)
        finally:
            # Always restore the working directory, even if generation fails,
            # so an error never leaves the process inside a case directory.
            os.chdir(original_dir)