"""
MXFP4 (Microscaling FP4 E2M1) 完整量化与反量化实现
====================================================
包含: FP4 编码表 | MX 量化/反量化 | 位级打包/解包 | E8M0 缩放因子
支持: Baseline 和 OAS (溢出感知缩放) 两种策略
本版本修复了以下 bug:
1. FP4 E2M1 编码表中 e=3 错误返回 inf/nan (应为 ±4.0/±6.0)
2. compute_e8m0_scale_block 未按 OCP MX 规范扣除 emax_elem,
且禁用了负指数 (scale<1 的 block 被错误量化)
3. mx_quantize 去 pad 时存在多余的逻辑判断
4. mx_dequantize_fp4_full 在 last dim 非 block_size 倍数时数据错位
5. scale 通过 log2 反算并以 E8M0 biased 编码存储 (符合规范)
"""
import torch
from typing import Tuple, Dict
# All 16 FP4 E2M1 code points in encoding order: codes 0b0000-0b0111 hold the
# non-negative magnitudes, codes 0b1000-0b1111 repeat them with the sign bit
# set (so code 8 is -0.0).
FP4_E2M1_VALUES = torch.tensor(
    [sign * magnitude for sign in (1.0, -1.0) for magnitude in (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)],
    dtype=torch.float32,
)
# Bias of the E8M0 shared-scale byte: stored byte = unbiased exponent + 127.
E8M0_BIAS = 127
# Unbiased exponent of the largest normal E2M1 value (6.0 = 1.5 * 2**2).
FP4_E2M1_EMAX = 2
def generate_fp4_e2m1_encoding_table():
    """
    Build the complete FP4 E2M1 code table.

    Each 4-bit code is laid out as S-E-E-M (1 sign bit, 2 exponent bits,
    1 mantissa bit, exponent bias 1). Exponent field 0 is subnormal
    (value = m * 0.5); otherwise value = 2**(e - 1) * (1 + m / 2).

    Returns:
        fp32_values: float32 tensor, shape [16], decoded value of codes 0..15.
        binary_encodings: uint8 tensor, shape [16], the codes 0..15.
        fp32_to_encoding: dict {value: code}. The first code seen for a value
            wins, so -0.0 (equal to 0.0) does not get its own entry (15 keys).
        encoding_to_fp32: dict {code: value} with all 16 codes.
    """

    def _decode(code):
        sign = -1.0 if code & 0b1000 else 1.0
        exp_field = (code >> 1) & 0b11
        mantissa_bit = code & 0b1
        if exp_field == 0:
            # Subnormal: no implicit leading 1.
            magnitude = mantissa_bit * 0.5
        else:
            magnitude = 2.0 ** (exp_field - 1) * (1.0 + 0.5 * mantissa_bit)
        return sign * magnitude

    codes = list(range(16))
    decoded = [_decode(code) for code in codes]

    fp32_to_encoding = {}
    for code, value in zip(codes, decoded):
        # setdefault keeps the first code, so -0.0 collapses onto code 0.
        fp32_to_encoding.setdefault(value, code)
    encoding_to_fp32 = dict(zip(codes, decoded))

    return (
        torch.tensor(decoded, dtype=torch.float32),
        torch.tensor(codes, dtype=torch.uint8),
        fp32_to_encoding,
        encoding_to_fp32,
    )
# Lookup tables built once at import time from generate_fp4_e2m1_encoding_table():
#   FP4_VALUES    - float32 tensor of the 16 decoded values, indexed by code
#   FP4_ENCODINGS - uint8 tensor holding the codes 0..15
#   FP32_TO_FP4   - {value: code}; -0.0 shares code 0 with 0.0 (15 entries)
#   FP4_TO_FP32   - {code: value} for all 16 codes
(
    FP4_VALUES,
    FP4_ENCODINGS,
    FP32_TO_FP4,
    FP4_TO_FP32,
) = generate_fp4_e2m1_encoding_table()
def quantize_fp4(tensor: torch.Tensor) -> torch.Tensor:
    """
    Round every element to the nearest FP4 (E2M1) value (FP32 simulation).

    Rounding is round-to-nearest with ties-to-even on the E2M1 mantissa bit
    (e.g. 0.25 -> 0.0, 0.75 -> 1.0, 2.5 -> 2.0, 3.5 -> 4.0, 5.0 -> 4.0), and it
    saturates: magnitudes above 6.0, including +-inf, become +-6.0. NaN
    becomes 0.0 because E2M1 has no NaN encoding.

    Args:
        tensor: tensor of any shape, normally already divided by its block scale.

    Returns:
        float32 tensor of the same shape whose values are all FP4 E2M1 code
        points. Negative inputs that round to zero come back as -0.0.
    """
    with torch.no_grad():
        # Non-negative E2M1 magnitudes in code order (codes 0..7). Because the
        # code is (exp << 1) | mantissa, a grid index's parity is its mantissa bit.
        magnitudes = torch.tensor(
            [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32, device=tensor.device
        )
        # Saturate first, so +-inf lands exactly on 6.0 instead of tying with
        # every grid point (inf - m == inf for all m).
        abs_vals = torch.abs(tensor).clamp(max=6.0)
        dist = torch.abs(abs_vals.unsqueeze(-1) - magnitudes)
        is_nearest = dist == dist.min(dim=-1, keepdim=True).values
        even_mantissa = torch.arange(8, device=tensor.device) % 2 == 0
        # Nearest candidates score >= 2. A tie only happens between two adjacent
        # grid points (one even, one odd index), and the even one scores 3, which
        # gives ties-to-even. For NaN nothing is "nearest", so argmax returns the
        # first maximum, index 0 (0.0).
        score = is_nearest.long() * 2 + even_mantissa.long()
        chosen = magnitudes[torch.argmax(score, dim=-1)]
        return torch.where(tensor < 0, -chosen, chosen)
def dequantize_fp4(fp4_data: torch.Tensor) -> torch.Tensor:
    """Return the simulated FP4 values as float32 (a plain dtype cast; no scale is applied)."""
    return fp4_data.to(dtype=torch.float32)
def compute_e8m0_scale_block(
    block: torch.Tensor, mode: str = "baseline", emax_elem: int = FP4_E2M1_EMAX
) -> torch.Tensor:
    """
    Compute the shared E8M0 scale (an integer power of two) for one MX block.

    Following the OCP MX conversion rule:
        baseline: scale_exp = floor(log2(max|block|)) - emax_elem
        oas:      scale_exp = ceil(log2(max|block|))  - emax_elem
    The exponent is then clamped to the E8M0 range [-127, 127].

    Args:
        block: tensor of any shape. The maximum is taken over ALL its elements,
            so pass exactly one block.
        mode: "baseline" (floor) scales the block maximum into [4, 8), so values
            above 6.0 saturate when rounded to FP4. "oas" (ceil) scales the
            maximum into (2, 4], so nothing saturates, at some precision cost.
            (Ignores the effect of exponent clamping.)
        emax_elem: exponent of the element format's largest normal value
            (2 for FP4 E2M1).

    Returns:
        0-dim tensor equal to 2**scale_exp. An all-zero or empty block returns
        1.0. fp16/bf16 inputs are evaluated in float32, so the scale cannot
        underflow to 0 in the input dtype; float32/float64 keep their precision.

    Raises:
        ValueError: if mode is not "baseline" or "oas".
    """
    if mode not in ("baseline", "oas"):
        raise ValueError(f"Unsupported mode: {mode}")
    if block.numel() == 0:
        return torch.tensor(1.0, device=block.device)
    max_abs = torch.max(torch.abs(block))
    # Upcast low-precision dtypes (never downcast float64). Otherwise 2**exp
    # below ~2**-24 is 0 in fp16 and the later block / scale blows up.
    max_abs = max_abs.to(torch.promote_types(max_abs.dtype, torch.float32))
    if max_abs.item() == 0.0:
        # Nothing to scale; 1.0 (unbiased exponent 0) keeps every element at 0.
        return torch.tensor(1.0, device=block.device)
    log_max = torch.log2(max_abs)
    rounded = torch.floor(log_max) if mode == "baseline" else torch.ceil(log_max)
    # E8M0 covers unbiased exponents -127..127 (biased byte 255 is NaN in the
    # OCP MX spec). Tiny nonzero blocks get the smallest scale rather than being
    # flushed to zero.
    exponent = torch.clamp(rounded - emax_elem, -127.0, 127.0)
    return 2.0**exponent
def mx_quantize(
    tensor: torch.Tensor, block_size: int = 32, mode: str = "baseline"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    MXFP4 quantization, simulated in floating point.

    The last axis is split into groups of ``block_size`` elements. It is
    zero-padded at the end when its length is not a multiple of ``block_size``.
    Each group gets one power-of-two scale from compute_e8m0_scale_block, is
    divided by that scale, and is rounded with quantize_fp4.

    Args:
        tensor: input tensor of any shape. A 0-dim input is treated as shape (1,).
        block_size: group size (the MX spec fixes it at 32).
        mode: scale strategy forwarded to compute_e8m0_scale_block
            ("baseline" or "oas").

    Returns:
        (fp4_vals, scales):
            fp4_vals: FP4 code-point values, NOT yet multiplied back by the
                scale, with the (possibly unsqueezed) input's shape and dtype.
            scales: 1-D tensor with one scale per block. Blocks are ordered
                row-major over the leading dims, then by block index along the
                last axis.
    """
    if tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    tail = tensor.shape[-1] % block_size
    pad_size = block_size - tail if tail else 0
    if pad_size:
        # Zero padding never raises a block's max |x|, so it cannot change the scale.
        tensor = torch.nn.functional.pad(tensor, (0, pad_size), mode='constant', value=0.0)
    blocks = tensor.reshape(-1, block_size)
    scales = torch.empty(blocks.shape[0], device=tensor.device)
    fp4_blocks = torch.empty_like(blocks)
    for idx, block in enumerate(blocks):
        block_scale = compute_e8m0_scale_block(block, mode=mode)
        scales[idx] = block_scale
        fp4_blocks[idx] = quantize_fp4(block / block_scale)
    fp4_vals = fp4_blocks.view(tensor.shape)
    return (fp4_vals[..., :-pad_size] if pad_size else fp4_vals), scales
def mx_dequantize(fp4_vals: torch.Tensor, scales: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """
    MXFP4 dequantization from the FP32-simulated representation.

    The last axis is re-blocked exactly as in mx_quantize (zero-padded up to a
    multiple of ``block_size``). Every block is multiplied by its scale, and the
    padding is trimmed again.

    Args:
        fp4_vals: simulated FP4 values (as returned by mx_quantize).
        scales: one scale per block, shape [num_blocks], in mx_quantize's order.
        block_size: group size used at quantization time.

    Returns:
        Dequantized tensor with the same shape as ``fp4_vals``.
    """
    has_dims = fp4_vals.dim() > 0
    last_dim = fp4_vals.shape[-1] if has_dims else 1
    tail = last_dim % block_size
    pad_size = block_size - tail if (has_dims and tail) else 0
    padded = torch.nn.functional.pad(fp4_vals, (0, pad_size), value=0.0) if pad_size else fp4_vals
    per_block_scale = scales.to(padded.device).unsqueeze(1)
    restored = (padded.reshape(-1, block_size) * per_block_scale).view(padded.shape)
    return restored[..., :last_dim] if pad_size else restored
def convert_fp32_simulated_to_fp4_encoding(fp32_simulated: torch.Tensor) -> torch.Tensor:
    """
    Map FP4-representable values (e.g. 1.0, 0.5, -0.5) to their 4-bit codes 0-15.

    Args:
        fp32_simulated: tensor whose values should all be FP4 E2M1 code points.

    Returns:
        uint8 tensor of the same shape and device holding the codes. Both 0.0
        and -0.0 map to code 0. Any value that is not an FP4 code point
        (including NaN) silently maps to code 0 as well. E2M1 has no inf/NaN,
        so there is no special-value handling.
    """
    device = fp32_simulated.device
    known_values = torch.tensor(list(FP32_TO_FP4.keys()), dtype=torch.float32, device=device)
    known_codes = torch.tensor(list(FP32_TO_FP4.values()), dtype=torch.int64, device=device)
    # [numel, num_keys] match matrix. The keys are distinct, so each row has at
    # most one hit, and summing code * hit yields that code (or 0 if none).
    hits = fp32_simulated.reshape(-1, 1) == known_values
    codes = (hits.to(torch.int64) * known_codes).sum(dim=1)
    return codes.to(torch.uint8).reshape(fp32_simulated.shape)
def pack_fp4_to_uint8(fp4_encodings: torch.Tensor) -> torch.Tensor:
    """
    Pack pairs of 4-bit FP4 codes into bytes (two codes per uint8).

    Within each adjacent pair along the last axis, the first code goes into
    the low nibble and the second into the high nibble.

    Args:
        fp4_encodings: [..., n] uint8 tensor of codes 0-15, with n even.

    Returns:
        [..., n // 2] packed tensor.

    Raises:
        ValueError: if the last dimension is odd.
    """
    shape = fp4_encodings.shape
    if shape[-1] % 2:
        raise ValueError(f"最后一维必须是偶数, 当前为 {shape[-1]}")
    pairs = fp4_encodings.reshape(-1, 2)
    low_nibble = pairs[:, 0] & 0x0F
    high_nibble = pairs[:, 1] << 4
    return (high_nibble | low_nibble).reshape(*shape[:-1], shape[-1] // 2)
def pack_mxfp4_full(fp4_simulated: torch.Tensor, scales: torch.Tensor, block_size: int = 32) -> Dict:
    """
    Pack mx_quantize() output into the compact MXFP4 byte layout.

    Args:
        fp4_simulated: simulated FP4 values from mx_quantize(). Every value must
            be an FP4 E2M1 code point; see convert_fp32_simulated_to_fp4_encoding.
        scales: per-block scales (powers of two) from mx_quantize().
        block_size: MX block size. Must be even so each block packs into whole bytes.

    Returns:
        dict with
            'fp4_data': uint8 tensor [num_blocks, block_size // 2], two codes per
                byte (low nibble = even element). Before blocking, the last axis
                is zero-padded with code 0 (+0.0) up to a multiple of block_size.
            'scales': uint8 E8M0 biased exponents, round(log2(scale)) + 127
                clamped to [0, 254].
            'block_size': int
            'original_shape': tuple, the unpadded shape of fp4_simulated
                (() for a 0-dim input).
    """
    original_shape = tuple(fp4_simulated.shape)
    fp4_encodings = convert_fp32_simulated_to_fp4_encoding(fp4_simulated)
    if fp4_encodings.dim() == 0:
        # A scalar has no last axis. Store it as a single element while
        # original_shape stays (), so the unpack side can rebuild a 0-dim tensor.
        fp4_encodings = fp4_encodings.reshape(1)
    pad_size = (-fp4_encodings.shape[-1]) % block_size
    if pad_size:
        fp4_encodings = torch.nn.functional.pad(fp4_encodings, (0, pad_size), mode='constant', value=0)
    packed_data = pack_fp4_to_uint8(fp4_encodings.reshape(-1, block_size))
    # Scales are exact powers of two, so rounding log2 recovers the integer exponent.
    scale_exps_int = torch.log2(scales).round().long()
    scale_exponents = (scale_exps_int + E8M0_BIAS).clamp(0, 254).to(torch.uint8)
    return {
        'fp4_data': packed_data,
        'scales': scale_exponents,
        'block_size': block_size,
        'original_shape': original_shape,
    }
def unpack_uint8_to_fp4(packed: torch.Tensor) -> torch.Tensor:
    """
    Split packed bytes back into 4-bit FP4 codes (the inverse of pack_fp4_to_uint8).

    Args:
        packed: [..., k] tensor of packed bytes.

    Returns:
        [..., 2 * k] uint8 tensor of codes 0-15. For each byte, the low nibble
        comes first and the high nibble second.
    """
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    interleaved = torch.stack((low, high), dim=-1).to(torch.uint8)
    return interleaved.reshape(*packed.shape[:-1], packed.shape[-1] * 2)
def unpack_mxfp4_full(packed_dict: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decode a packed MXFP4 dict into per-block FP4 values and real scales.

    Args:
        packed_dict: dict produced by pack_mxfp4_full().

    Returns:
        (fp32_values_blocks, scales):
            fp32_values_blocks: float32 decoded FP4 values, NOT yet multiplied by
                the scale. Shape is fp4_data's shape with its last axis doubled,
                i.e. [num_blocks, block_size] for pack_mxfp4_full output.
            scales: float tensor 2 ** (biased_exponent - 127), one per block.
    """
    codes = unpack_uint8_to_fp4(packed_dict['fp4_data'])
    # Code -> value lookup table; code 8 decodes to -0.0.
    decode_table = torch.tensor(
        [FP4_TO_FP32[code] for code in range(16)], dtype=torch.float32, device=codes.device
    )
    block_values = decode_table[codes.long()]
    biased_exponents = packed_dict['scales'].float()
    return block_values, 2.0 ** (biased_exponents - E8M0_BIAS)
def mx_quantize_fp4_full(tensor: torch.Tensor, block_size: int = 32, mode: str = "baseline") -> Dict:
    """
    End-to-end MXFP4 quantization: input tensor -> packed byte dict.

    Runs mx_quantize() and then pack_mxfp4_full().

    Args:
        tensor: input tensor.
        block_size: MX block size (default 32).
        mode: scale strategy, "baseline" or "oas".

    Returns:
        The dict produced by pack_mxfp4_full(), ready to store or transmit.
    """
    simulated, shared_scales = mx_quantize(tensor, block_size=block_size, mode=mode)
    packed = pack_mxfp4_full(simulated, shared_scales, block_size=block_size)
    return packed
def mx_dequantize_fp4_full(packed_dict: Dict) -> torch.Tensor:
    """
    End-to-end MXFP4 dequantization: packed byte dict -> tensor.

    Handles a last axis that is not a multiple of block_size:
      1. decode the blocks and multiply each [block_size] row by its scale,
      2. reshape to the padded layout,
      3. trim the zero padding off the last axis.

    Args:
        packed_dict: dict produced by mx_quantize_fp4_full().

    Returns:
        float32 tensor of shape packed_dict['original_shape']. A 0-dim tensor
        is returned when that shape is ().
    """
    block_values, block_scales = unpack_mxfp4_full(packed_dict)
    block_size = packed_dict['block_size']
    original_shape = packed_dict['original_shape']
    scaled_blocks = block_values * block_scales.unsqueeze(1).to(block_values.device)
    if not original_shape:
        return scaled_blocks.flatten()[0]
    *leading, last_dim = original_shape
    # Padding needed to reach the next multiple of block_size (0 if already aligned).
    pad_size = (-last_dim) % block_size
    padded = scaled_blocks.reshape(tuple(leading) + (last_dim + pad_size,))
    return padded[..., :last_dim] if pad_size else padded
def mxfp4_quantize_pack_last(
    tensor: torch.Tensor, quant_axis: int, block_size: int = 32, mode: str = "baseline"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    MXFP4-quantize along any axis, while always packing nibbles along the last axis.

    The quantization axis may equal the packing axis. When quant_axis is the
    last axis, the result is the standard OCP MXFP4 layout.

    Args:
        tensor: tensor with at least one dim. The last dimension must be even
            (packing requirement).
        quant_axis: axis whose consecutive ``block_size`` elements share one scale.
            Negative indices are allowed; it must lie in [-tensor.dim(), tensor.dim()).
        block_size: MX block size (the spec fixes it at 32).
        mode: "baseline" (floor) or "oas" (ceil).

    Returns:
        (packed_fp4, e8m0_scales):
            packed_fp4: uint8, the input shape with the last dim halved. Nibble
                order is low = elem[2k], high = elem[2k+1].
            e8m0_scales: uint8 biased exponents (+127), the input shape with the
                quant_axis dim replaced by ceil(quant_dim / block_size).

    Raises:
        ValueError: if the tensor is 0-dim or its last dimension is odd.
        IndexError: if quant_axis is out of range.
    """
    nd = tensor.dim()
    if nd == 0:
        raise ValueError("mxfp4_quantize_pack_last requires a tensor with at least 1 dim")
    # Reject instead of silently wrapping (e.g. quant_axis=5 on a 3-D tensor).
    if not -nd <= quant_axis < nd:
        raise IndexError(f"quant_axis {quant_axis} is out of range for a {nd}-D tensor")
    qax = quant_axis % nd
    last_dim = tensor.shape[-1]
    quant_dim = tensor.shape[qax]
    # An explicit raise (not assert) so the check survives `python -O`.
    if last_dim % 2 != 0:
        raise ValueError(f"最后一维需为偶数 (打包要求), 当前 {last_dim}")
    # Move the quantization axis last so mx_quantize's last-axis blocking applies to it.
    t_q_last = tensor.movedim(qax, -1).contiguous()
    fp4_sim, scales_flat = mx_quantize(t_q_last, block_size=block_size, mode=mode)
    fp4_codes_q_last = convert_fp32_simulated_to_fp4_encoding(fp4_sim)
    # Restore the original layout, then pack adjacent pairs along the real last axis.
    fp4_codes = fp4_codes_q_last.movedim(-1, qax).contiguous()
    packed_fp4 = pack_fp4_to_uint8(fp4_codes)
    # mx_quantize returns scales row-major over t_q_last's leading dims, then by block.
    n_blk_q = (quant_dim + block_size - 1) // block_size
    scales_q_last = scales_flat.reshape(*t_q_last.shape[:-1], n_blk_q)
    scales_real = scales_q_last.movedim(-1, qax).contiguous()
    # Scales are exact powers of two, so rounding log2 recovers the integer exponent.
    scale_exp_int = torch.log2(scales_real).round().long()
    e8m0_scales = (scale_exp_int + E8M0_BIAS).clamp(0, 254).to(torch.uint8)
    return packed_fp4, e8m0_scales
def linear_mxfp4(
    x: torch.Tensor,
    w: torch.Tensor,
    block_size: int = 32,
    mode: str = "baseline",
    quant_axis: int = -1,
) -> torch.Tensor:
    """
    Simulate a linear layer with MXFP4-quantized weights: x @ dequant(quant(w)).

    Args:
        x: input activations, kept at full precision.
        w: weights, used as the right-hand operand of ``x @ w``.
        block_size: MX block size.
        mode: scale strategy, "baseline" or "oas".
        quant_axis: axis of ``w`` along which ``block_size`` consecutive elements
            share one scale.
            - The default -1 keeps the original behaviour (blocks along w's last
              axis, the output-feature axis of ``x @ w``).
            - MX matmul hardware shares scales along the reduction axis. For a
              >=2-D ``w`` in ``x @ w`` that is ``quant_axis=-2``.

    Returns:
        x @ dequantized(w). Gradients can reach ``x`` but not ``w``, because the
        quantize/dequantize step runs under torch.no_grad().

    Raises:
        IndexError: if quant_axis is out of range for ``w``.
    """
    # The whole weight round trip is non-differentiable (argmin-style rounding),
    # so do all of it without recording autograd history.
    with torch.no_grad():
        nd = w.dim()
        if nd > 0 and not -nd <= quant_axis < nd:
            raise IndexError(f"quant_axis {quant_axis} is out of range for a {nd}-D weight")
        # Only permute when blocking along a non-last axis; the default path is
        # exactly the original last-axis computation.
        move = nd > 0 and quant_axis % nd != nd - 1
        w_blocked = w.movedim(quant_axis, -1) if move else w
        qw, sw = mx_quantize(w_blocked, block_size, mode)
        dqw = mx_dequantize(qw, sw, block_size)
        if move:
            dqw = dqw.movedim(-1, quant_axis)
    return x @ dqw
if __name__ == "__main__":
    # Self-test / demo script.
    # - `_info` lines only report quantization error.
    # - `_check` / `_check_true` assert invariants that must hold exactly, e.g.
    #   the packed byte path reproducing the FP32-simulated path bit-for-bit.
    print("=" * 60)
    print("MXFP4 (E2M1) 完整量化与反量化测试")
    print("=" * 60)

    def _check(name, a, b, atol=1e-6, rtol=1e-5):
        """Assert torch.allclose(a, b, atol, rtol) for an invariant and print the max abs diff."""
        diff = (a - b).abs().max().item() if a.numel() > 0 else 0.0
        ok = torch.allclose(a, b, atol=atol, rtol=rtol)
        print(f" [{'OK ' if ok else 'FAIL'}] {name}: max diff {diff:.3e}")
        assert ok, f"{name}: max diff {diff:.3e} > atol={atol:.0e}"

    def _check_true(name, cond):
        """Assert that `cond` is truthy and print it."""
        ok = bool(cond)
        print(f" [{'OK ' if ok else 'FAIL'}] {name}: {cond}")
        assert ok, f"{name} failed"

    def _info(name, value):
        """Print a descriptive metric (e.g. quantization error) without asserting."""
        print(f" [info] {name}: {value}")

    torch.manual_seed(42)
    w = torch.randn(128, 256, dtype=torch.float32)

    # Test 1: simulated quantize/dequantize round trip with both scale
    # strategies; errors are reported, not asserted.
    print("\n--- 测试 1: 基础量化 (FP32 模拟) ---")
    q, s = mx_quantize(w, block_size=32, mode="baseline")
    dq = mx_dequantize(q, s, block_size=32)
    error = torch.abs(w - dq).mean().item()
    max_error = torch.abs(w - dq).max().item()
    cos_sim = torch.nn.functional.cosine_similarity(w.flatten(), dq.flatten(), dim=0)
    _info("Baseline 平均绝对误差", f"{error:.6f}")
    _info("Baseline 最大绝对误差", f"{max_error:.6f}")
    _info("Baseline 余弦相似度", f"{cos_sim.item():.6f}")
    q_oas, s_oas = mx_quantize(w, block_size=32, mode="oas")
    dq_oas = mx_dequantize(q_oas, s_oas, block_size=32)
    error_oas = torch.abs(w - dq_oas).mean().item()
    _info("OAS 平均绝对误差", f"{error_oas:.6f}")

    # Test 2: compare x @ w against x @ dequant(quant(w)) for both strategies.
    print("\n--- 测试 2: 线性层推理模拟 ---")
    x = torch.randn(16, 128, dtype=torch.float32)
    output_fp32 = x @ w
    output_mx = linear_mxfp4(x, w, mode="baseline")
    output_mx_oas = linear_mxfp4(x, w, mode="oas")
    cos_sim_out = torch.nn.functional.cosine_similarity(output_fp32.flatten(), output_mx.flatten(), dim=0)
    cos_sim_out_oas = torch.nn.functional.cosine_similarity(output_fp32.flatten(), output_mx_oas.flatten(), dim=0)
    _info("Baseline (推理) 余弦相似度", f"{cos_sim_out.item():.6f}")
    _info("OAS (推理) 余弦相似度", f"{cos_sim_out_oas.item():.6f}")

    # Test 3: full byte-packing path. Its output must equal Test 1's baseline
    # `dq` exactly (atol=rtol=0).
    print("\n--- 测试 3: 完整打包流程 (FP32 -> FP4 字节) ---")
    packed = mx_quantize_fp4_full(w, block_size=32, mode="baseline")
    original_bytes = w.numel() * 4
    # Theoretical size: 4 bits per element plus one E8M0 byte per 32-element block.
    actual_compressed = w.numel() * 0.5 + (w.numel() // 32)
    print(f" 原始大小 (FP32): {original_bytes:,} bytes")
    print(f" 压缩后大小 (MXFP4): {actual_compressed:,.0f} bytes")
    print(f" 压缩比: {actual_compressed / original_bytes:.2%}")
    print(f" 打包数据形状: {packed['fp4_data'].shape}")
    print(f" Scale 形状: {packed['scales'].shape}")
    print(f" 每 block 字节数: {32 * 0.5 + 1:.1f} (16 FP4 + 1 scale)")
    dq_packed = mx_dequantize_fp4_full(packed)
    error_packed = torch.abs(w - dq_packed).mean().item()
    _info("打包后恢复误差 (量化噪声)", f"{error_packed:.6f}")
    _check("打包路径 ≡ 模拟路径", dq, dq_packed, atol=0.0, rtol=0.0)

    # Test 4: code-table sanity. FP32_TO_FP4 has one entry fewer than
    # FP4_TO_FP32 because -0.0 == 0.0 share a key.
    print("\n--- 测试 4: FP4 编码表验证 ---")
    print(f" FP4 可表示值数量: {len(FP4_VALUES)}")
    print(f" FP32->FP4 映射数量: {len(FP32_TO_FP4)}")
    print(f" FP4->FP32 映射数量: {len(FP4_TO_FP32)}")
    # `v != v` is the NaN test.
    has_special = any((v != v) or (v == float('inf')) or (v == float('-inf')) for v in FP4_TO_FP32.values())
    _check_true("编码表不含 inf/NaN", not has_special)
    test_tensor = torch.tensor([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -6.0, 4.0])
    enc = convert_fp32_simulated_to_fp4_encoding(test_tensor)
    print(f"\n 测试值: {test_tensor.tolist()}")
    print(f" 编码: {[bin(e.item())[2:].zfill(4) for e in enc]}")
    # Value -> code -> value round trip over all 16 code points
    # (-0.0 comes back as 0.0, which allclose treats as equal).
    test_vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])
    enc_all = convert_fp32_simulated_to_fp4_encoding(test_vals)
    recovered = torch.zeros_like(test_vals)
    for i, e in enumerate(enc_all):
        recovered[i] = FP4_TO_FP32[e.item()]
    _check("全 16 值往返", test_vals, recovered, atol=0.0, rtol=0.0)
    _info("恢复值", f"{recovered.tolist()}")

    # Test 5: several shapes, including last dims that are not multiples of 32
    # (50, 33). Packed and simulated paths must agree exactly.
    print("\n--- 测试 5: 不同形状张量支持 ---")
    shapes = [(32,), (64, 32), (16, 128, 64), (3, 50), (7, 33)]
    for shape in shapes:
        t = torch.randn(shape)
        q, s = mx_quantize(t, block_size=32)
        dq = mx_dequantize(q, s, block_size=32)
        err_sim = torch.abs(t - dq).mean().item()
        packed = mx_quantize_fp4_full(t, block_size=32)
        dq_pk = mx_dequantize_fp4_full(packed)
        err_pk = torch.abs(t - dq_pk).mean().item()
        _check(f"形状 {list(shape)} 打包路径 ≡ 模拟路径", dq, dq_pk, atol=0.0, rtol=0.0)
        _info(f"形状 {list(shape)} 量化误差", f"sim={err_sim:.6f}, pack={err_pk:.6f}")

    # Test 6: small-magnitude inputs (~0.01) need negative scale exponents;
    # relative error is reported, not asserted.
    print("\n--- 测试 6: 小值张量 (验证负指数支持) ---")
    w_small = torch.randn(64, 64) * 0.01
    q_s, s_s = mx_quantize(w_small, block_size=32)
    dq_s = mx_dequantize(q_s, s_s, block_size=32)
    err_small = torch.abs(w_small - dq_s).mean().item()
    rel_err = err_small / w_small.abs().mean().item()
    _info("输入 abs 均值", f"{w_small.abs().mean().item():.6f}")
    _info("反量化绝对误差", f"{err_small:.6f}")
    _info("相对误差", f"{rel_err:.2%}")

    # Test 7: mxfp4_quantize_pack_last with the quant axis equal to and
    # different from the (last) packing axis. Each result is dequantized by hand
    # and compared exactly with a movedim-based reference.
    print("\n--- 测试 7: mxfp4_quantize_pack_last ---")

    def _dequant_from_packed(packed_fp4, e8m0_scales, quant_axis, orig_shape, block_size=32):
        """Unpack nibbles -> FP4 codes -> FP32, multiply by scales along quant_axis, restore the layout."""
        nd = len(orig_shape)
        qax = quant_axis % nd
        codes = unpack_uint8_to_fp4(packed_fp4)
        fp4_sim = torch.zeros_like(codes, dtype=torch.float32)
        for enc, val in FP4_TO_FP32.items():
            fp4_sim[codes == enc] = val
        scales_real = 2.0 ** (e8m0_scales.float() - E8M0_BIAS)
        fp4_sim_q_last = fp4_sim.movedim(qax, -1).contiguous()
        scales_q_last = scales_real.movedim(qax, -1).contiguous()
        quant_dim = orig_shape[qax]
        # Pad the quant axis to a whole number of blocks so each scale can be
        # broadcast over exactly block_size elements.
        pad_size = (block_size - quant_dim % block_size) % block_size
        if pad_size > 0:
            fp4_sim_q_last = torch.nn.functional.pad(fp4_sim_q_last, (0, pad_size))
        scales_expanded = scales_q_last.repeat_interleave(block_size, dim=-1)
        dequant_padded = fp4_sim_q_last * scales_expanded
        if pad_size > 0:
            dequant_padded = dequant_padded[..., :quant_dim]
        return dequant_padded.movedim(-1, qax).contiguous()

    def _ref_quant_dequant_along(t, qax, block_size=32, mode="baseline"):
        """Reference path: move the quant axis last, run standard MXFP4 quant/dequant, move it back."""
        t_q_last = t.movedim(qax, -1).contiguous()
        ref_q, ref_s = mx_quantize(t_q_last, block_size=block_size, mode=mode)
        ref_dq = mx_dequantize(ref_q, ref_s, block_size=block_size)
        return ref_dq.movedim(-1, qax).contiguous()

    # Same axis: quantize and pack along the last dim.
    t1 = torch.randn(4, 8, 32, 64)
    packed1, scales1 = mxfp4_quantize_pack_last(t1, quant_axis=-1)
    print(f" [shape] 同轴 (qax=-1): in={list(t1.shape)} packed={list(packed1.shape)} scales={list(scales1.shape)}")
    ref1 = _ref_quant_dequant_along(t1, -1)
    new1 = _dequant_from_packed(packed1, scales1, -1, t1.shape)
    _check("同轴 (qax=-1) 反量化 ≡ 参考路径", ref1, new1, atol=0.0, rtol=0.0)

    # Different axis whose length (33) is not a multiple of the block size.
    t2 = torch.randn(4, 8, 33, 64)
    packed2, scales2 = mxfp4_quantize_pack_last(t2, quant_axis=2)
    print(
        f" [shape] 异轴 (qax=2, S=33): in={list(t2.shape)} packed={list(packed2.shape)} scales={list(scales2.shape)}"
    )
    ref2 = _ref_quant_dequant_along(t2, 2)
    new2 = _dequant_from_packed(packed2, scales2, 2, t2.shape)
    _check("异轴 (qax=2) 反量化 ≡ movedim 参考路径", ref2, new2, atol=0.0, rtol=0.0)

    # Quantize along the leading axis.
    t3 = torch.randn(64, 16, 32)
    packed3, scales3 = mxfp4_quantize_pack_last(t3, quant_axis=0)
    print(f" [shape] 异轴 (qax=0): in={list(t3.shape)} packed={list(packed3.shape)} scales={list(scales3.shape)}")
    ref3 = _ref_quant_dequant_along(t3, 0)
    new3 = _dequant_from_packed(packed3, scales3, 0, t3.shape)
    _check("异轴 (qax=0) 反量化 ≡ movedim 参考路径", ref3, new3, atol=0.0, rtol=0.0)

    # OAS strategy on a middle axis of length 100 (not a multiple of 32).
    t4 = torch.randn(2, 100, 16) * 0.5
    packed4, scales4 = mxfp4_quantize_pack_last(t4, quant_axis=1, mode="oas")
    print(f" [shape] OAS 异轴 (qax=1): in={list(t4.shape)} packed={list(packed4.shape)} scales={list(scales4.shape)}")
    ref4 = _ref_quant_dequant_along(t4, 1, mode="oas")
    new4 = _dequant_from_packed(packed4, scales4, 1, t4.shape)
    _check("OAS 异轴 (qax=1) 反量化 ≡ movedim 参考路径", ref4, new4, atol=0.0, rtol=0.0)

    print("\n" + "=" * 60)
    print("测试完成!")
    print("=" * 60)