import argparse
import os
from typing import Dict, Tuple
import torch
_BLOCK_SIZE = 32
_EPSILON = 1e-12
_MIN_SCALE_EXP = -128
_MAX_SCALE_EXP = 127
_FP8_FORMATS = {
"E4M3": {
"torch_dtype": torch.float8_e4m3fn,
"exp_bits": 4,
"mantissa_bits": 3,
"bias": 7,
"emax": 8,
"max_value": 448.0,
"min_value": -448.0,
},
"E5M2": {
"torch_dtype": torch.float8_e5m2,
"exp_bits": 5,
"mantissa_bits": 2,
"bias": 15,
"emax": 15,
"max_value": 57344.0,
"min_value": -57344.0,
},
}
_FP4_FORMATS: Dict[str, Dict[str, float]] = {
"E2M1": {
"exp_bits": 2,
"mantissa_bits": 1,
"bias": 1,
"emax": 2,
"max_value": 6.0,
"min_value": -6.0,
},
"E1M2": {
"exp_bits": 1,
"mantissa_bits": 2,
"bias": 1,
"emax": 0,
"max_value": 1.75,
"min_value": -1.75,
},
}
def _build_e4m3_lut() -> torch.Tensor:
bias = _FP8_FORMATS["E4M3"]["bias"]
fp8_max = _FP8_FORMATS["E4M3"]["max_value"]
values = []
for i in range(256):
if i < 128:
sign, val = 1, i
else:
sign, val = -1, i - 128
if val == 0:
v = 0.0
elif val == 127:
v = sign * fp8_max
else:
exp = (val >> 3) & 0x0F
mantissa = val & 0x07
if exp == 0:
v = (mantissa / 8.0) * (2.0 ** (1 - bias))
else:
v = (1.0 + mantissa / 8.0) * (2.0 ** (exp - bias))
v *= sign
v = max(min(v, fp8_max), -fp8_max)
values.append(v)
return torch.tensor(values, dtype=torch.float32)
def _build_e5m2_lut() -> torch.Tensor:
bias = _FP8_FORMATS["E5M2"]["bias"]
fp8_max = _FP8_FORMATS["E5M2"]["max_value"]
values = []
for i in range(256):
if i < 128:
sign, val = 1, i
else:
sign, val = -1, i - 128
if val == 0:
v = 0.0
elif 124 <= val <= 127:
v = sign * fp8_max
else:
exp = (val >> 2) & 0x1F
mantissa = val & 0x03
if exp == 0:
v = (mantissa / 4.0) * (2.0 ** (1 - bias))
else:
v = (1.0 + mantissa / 4.0) * (2.0 ** (exp - bias))
v *= sign
v = max(min(v, fp8_max), -fp8_max)
values.append(v)
return torch.tensor(values, dtype=torch.float32)
_FP8_LUT_BUILDERS = {"E4M3": _build_e4m3_lut, "E5M2": _build_e5m2_lut}
_FP8_LUT_CACHE = {}
_FP8_LUT_POS_CACHE = {}
def _get_fp8_lut(format_name: str) -> torch.Tensor:
if format_name not in _FP8_LUT_CACHE:
_FP8_LUT_CACHE[format_name] = _FP8_LUT_BUILDERS[format_name]()
return _FP8_LUT_CACHE[format_name]
def _get_fp8_lut_pos(format_name: str) -> torch.Tensor:
if format_name not in _FP8_LUT_POS_CACHE:
full = _get_fp8_lut(format_name)
pos = full[:128].contiguous()
diffs = pos[1:] - pos[:-1]
if (diffs < 0).any():
raise AssertionError(
f"{format_name} positive LUT half is not non-decreasing")
_FP8_LUT_POS_CACHE[format_name] = pos
return _FP8_LUT_POS_CACHE[format_name]
def _build_fp4_lut(format_name: str) -> torch.Tensor:
config = _FP4_FORMATS[format_name]
exp_bits = int(config["exp_bits"])
mantissa_bits = int(config["mantissa_bits"])
bias = float(config["bias"])
values = []
for i in range(16):
sign = (i >> 3) & 0x01
exp = (i >> mantissa_bits) & ((1 << exp_bits) - 1)
mantissa = i & ((1 << mantissa_bits) - 1)
if exp == 0:
if mantissa == 0:
value = 0.0
else:
value = (mantissa / float(1 << mantissa_bits)) * (2.0 ** (1.0 - bias))
else:
value = (1.0 + mantissa / float(1 << mantissa_bits)) * (2.0 ** (float(exp) - bias))
if sign == 1:
value = -value
values.append(value)
return torch.tensor(values, dtype=torch.float32)
_FP4_LUT = {
"E2M1": _build_fp4_lut("E2M1"),
"E1M2": _build_fp4_lut("E1M2"),
}
def _e8m0_exp(max_abs: torch.Tensor, emax: int,
epsilon: float = _EPSILON) -> torch.Tensor:
assert max_abs.dtype == torch.float32, max_abs.dtype
zero_mask = max_abs < epsilon
safe = torch.where(zero_mask, torch.ones_like(max_abs), max_abs)
bits = safe.contiguous().view(torch.int32)
exp_bits = (bits >> 23) & 0xFF
exp = exp_bits - 127 - emax
exp = exp.clamp(_MIN_SCALE_EXP, _MAX_SCALE_EXP)
return torch.where(zero_mask, torch.zeros_like(exp), exp)
def _vectorized_lut_quantize_fp8(scaled: torch.Tensor, format_name: str,
fp8_dtype: torch.dtype) -> torch.Tensor:
lut_pos = _get_fp8_lut_pos(format_name)
last_idx = lut_pos.numel() - 1
sign = torch.sign(scaled)
mag = scaled.abs()
upper_idx = torch.searchsorted(lut_pos, mag).clamp(max=last_idx)
lower_idx = (upper_idx - 1).clamp(min=0)
upper_val = lut_pos[upper_idx]
lower_val = lut_pos[lower_idx]
pick_lower = (mag - lower_val) <= (upper_val - mag)
chosen_mag = torch.where(pick_lower, lower_val, upper_val)
snapped_fp32 = sign * chosen_mag
zero_mask = chosen_mag == 0
snapped_fp32 = torch.where(
zero_mask, torch.zeros_like(snapped_fp32), snapped_fp32)
return snapped_fp32.to(fp8_dtype)
def _quantize_to_fp4_lut(values: torch.Tensor, format_name: str) -> Tuple[torch.Tensor, torch.Tensor]:
lut = _FP4_LUT[format_name].to(values.device)
min_value = _FP4_FORMATS[format_name]["min_value"]
max_value = _FP4_FORMATS[format_name]["max_value"]
clamped = values.clamp(min_value, max_value)
distances = (clamped.unsqueeze(-1) - lut).abs()
indices = torch.argmin(distances, dim=-1)
quantized = lut[indices]
return quantized, indices.to(torch.uint8)
def _pack_fp4_nibbles(index_matrix: torch.Tensor) -> torch.Tensor:
rows, cols = index_matrix.shape
if cols % 2 != 0:
index_matrix = torch.cat(
[index_matrix, torch.zeros((rows, 1), dtype=torch.uint8, device=index_matrix.device)],
dim=1,
)
low = index_matrix[:, 0::2]
high = index_matrix[:, 1::2] << 4
packed = low | high
return packed.to(torch.uint8)
def _quantize_fp8_axis_last(matrix: torch.Tensor, format_name: str,
block_size: int = _BLOCK_SIZE
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
M, N = matrix.shape
fmt = _FP8_FORMATS[format_name]
fp8_dtype = fmt["torch_dtype"]
fp8_emax = fmt["emax"]
fp8_max = fmt["max_value"]
num_blocks = (N + block_size - 1) // block_size
padded_n = num_blocks * block_size
if padded_n != N:
padded = torch.zeros(M, padded_n, dtype=matrix.dtype)
padded[:, :N] = matrix
else:
padded = matrix
blocks = padded.view(M, num_blocks, block_size)
max_abs = blocks.abs().amax(dim=-1)
exp = _e8m0_exp(max_abs, fp8_emax)
scale = torch.exp2(exp.to(torch.float32))
scaled = blocks / scale.unsqueeze(-1)
scaled_clamped = scaled.clamp(-fp8_max, fp8_max)
quant_fp8 = _vectorized_lut_quantize_fp8(scaled_clamped, format_name, fp8_dtype)
dequant = quant_fp8.to(torch.float32) * scale.unsqueeze(-1)
if padded_n != N:
quant_fp8 = quant_fp8.reshape(M, padded_n)[:, :N].contiguous()
dequant = dequant.reshape(M, padded_n)[:, :N].contiguous()
else:
quant_fp8 = quant_fp8.reshape(M, N)
dequant = dequant.reshape(M, N)
padded_blocks = ((num_blocks + 1) // 2) * 2
if padded_blocks != num_blocks:
scale_padded = torch.ones((M, padded_blocks), dtype=torch.float32)
scale_padded[:, :num_blocks] = scale
scale = scale_padded
return quant_fp8, scale, dequant
def _quantize_fp8_axis_first(matrix: torch.Tensor, format_name: str,
block_size: int = _BLOCK_SIZE
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
qt, st, dt = _quantize_fp8_axis_last(
matrix.t().contiguous(), format_name, block_size)
return (qt.t().contiguous(),
st.t().contiguous(),
dt.t().contiguous())
def _quantize_fp8(matrix: torch.Tensor, format_name: str, axis: int,
block_size: int = _BLOCK_SIZE
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if axis == 0:
return _quantize_fp8_axis_first(matrix, format_name, block_size)
if axis == 1:
return _quantize_fp8_axis_last(matrix, format_name, block_size)
raise ValueError(f"axis must be 0 or 1, got {axis}")
def _quantize_fp4_axis_last(matrix: torch.Tensor, format_name: str,
block_size: int = _BLOCK_SIZE
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
m, n = matrix.shape
padded_n = ((n + block_size - 1) // block_size) * block_size
num_blocks = padded_n // block_size
if padded_n != n:
padded = torch.zeros((m, padded_n), dtype=matrix.dtype, device=matrix.device)
padded[:, :n] = matrix
else:
padded = matrix
blocks = padded.view(m, num_blocks, block_size)
max_abs = blocks.abs().amax(dim=-1)
exp = torch.floor(torch.log2(torch.clamp(max_abs, min=_EPSILON))) - _FP4_FORMATS[format_name]["emax"]
exp = torch.where(max_abs < _EPSILON, torch.zeros_like(exp), exp)
exp = exp.clamp(_MIN_SCALE_EXP, _MAX_SCALE_EXP)
scale = torch.pow(torch.tensor(2.0, dtype=torch.float32, device=matrix.device), exp)
scaled = blocks / scale.unsqueeze(-1)
quantized_blocks, _ = _quantize_to_fp4_lut(scaled, format_name)
dequant_blocks = quantized_blocks * scale.unsqueeze(-1)
quantized = quantized_blocks.reshape(m, padded_n)
dequantized = dequant_blocks.reshape(m, padded_n)
if padded_n != n:
quantized = quantized[:, :n].contiguous()
dequantized = dequantized[:, :n].contiguous()
padded_blocks = ((num_blocks + 1) // 2) * 2
if padded_blocks != num_blocks:
scale_padded = torch.ones((m, padded_blocks), dtype=torch.float32, device=matrix.device)
scale_padded[:, :num_blocks] = scale
scale = scale_padded
return quantized, scale, dequantized
def _quantize_fp4_axis_first(matrix: torch.Tensor, format_name: str,
block_size: int = _BLOCK_SIZE
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
quantized_t, scale_t, dequantized_t = _quantize_fp4_axis_last(
matrix.t().contiguous(), format_name, block_size)
return quantized_t.t().contiguous(), scale_t.t().contiguous(), dequantized_t.t().contiguous()
def _quantize_fp4(matrix: torch.Tensor, format_name: str, axis: int,
block_size: int = _BLOCK_SIZE
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if axis == 0:
return _quantize_fp4_axis_first(matrix, format_name, block_size)
if axis == 1:
return _quantize_fp4_axis_last(matrix, format_name, block_size)
raise ValueError(f"axis must be 0 or 1, got {axis}")
def gen_data_fp8_e4m3(row, col, axis):
matrix = torch.randn((row, col), dtype=torch.float32) * 10
quant_fp8, scale_fp32, dequant_fp32 = _quantize_fp8(matrix, "E4M3", axis)
return (quant_fp8.to(torch.float8_e4m3fn),
scale_fp32.to(torch.float8_e8m0fnu),
dequant_fp32)
def gen_data_fp8_e5m2(row, col, axis):
matrix = torch.randn((row, col), dtype=torch.float32)
quant_fp8, scale_fp32, dequant_fp32 = _quantize_fp8(matrix, "E5M2", axis)
return (quant_fp8.to(torch.float8_e5m2),
scale_fp32.to(torch.float8_e8m0fnu),
dequant_fp32)
def gen_data_fp4_e2m1(row, col, axis, trans):
matrix = torch.randn((row, col), dtype=torch.float32)
quantized_matrix, scale_matrix, dequantized_matrix = _quantize_fp4(matrix, "E2M1", axis)
if trans == 1:
quantized_matrix = quantized_matrix.t().contiguous()
_, fp4_indices = _quantize_to_fp4_lut(quantized_matrix, "E2M1")
quantized_matrix_uint8 = _pack_fp4_nibbles(fp4_indices)
return quantized_matrix_uint8, scale_matrix.to(torch.float8_e8m0fnu), dequantized_matrix
def gen_data(m, n, k, trans_a, trans_b) -> None:
os.makedirs('./input', exist_ok=True)
os.makedirs('./golden', exist_ok=True)
a_fp8, a_scale, a_fp32 = gen_data_fp8_e4m3(m, k, 1)
b_fp4, b_scale, b_fp32 = gen_data_fp4_e2m1(k, n, 0, trans_b)
a_scale = a_scale.reshape(a_scale.shape[0], a_scale.shape[1] // 2, 2)
b_scale = b_scale.reshape(b_scale.shape[0] // 2, 2, b_scale.shape[1])
if trans_a == 1:
a_scale = a_scale.permute(1, 0, 2)
if trans_b == 1:
b_scale = b_scale.permute(2, 0, 1)
else:
b_scale = b_scale.permute(0, 2, 1)
a_np = torch.tensor(a_fp8.flatten().untyped_storage(), dtype=torch.int8).numpy()
b_np = torch.tensor(b_fp4.flatten().untyped_storage(), dtype=torch.int8).numpy()
a_np.tofile('./input/a_8.bin')
b_np.tofile('./input/b_4.bin')
a_scale_np = torch.tensor(a_scale.flatten().untyped_storage(), dtype=torch.int8).numpy()
b_scale_np = torch.tensor(b_scale.flatten().untyped_storage(), dtype=torch.int8).numpy()
a_scale_np.tofile('./input/a_scale.bin')
b_scale_np.tofile('./input/b_scale.bin')
c_fp32 = a_fp32 @ b_fp32
c_np = c_fp32.numpy()
c_np.tofile('./golden/expected_data.bin')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("m", type=int)
parser.add_argument("n", type=int)
parser.add_argument("k", type=int)
parser.add_argument("trans_a", type=int)
parser.add_argument("trans_b", type=int)
args = parser.parse_args()
gen_data(args.m, args.n, args.k, args.trans_a, args.trans_b)