import math
import pytest
import torch
from transformer_engine.pytorch.constants import TensorUsage
from transformer_engine.pytorch.quantization import is_mxfp4_available
from transformer_engine.pytorch.tensor import MXFP4Quantizer, MXFP4Tensor
from utils import npu_available
# Probe both preconditions once at import time so the skip condition and the
# skip reason are derived from the same values.
_mxfp4_available, _reason_for_no_mxfp4 = is_mxfp4_available(return_reason=True)
_npu_available = npu_available()

# Report the precondition that actually failed. A missing device is reported
# first; otherwise use the reason supplied by the MXFP4 availability check,
# falling back to a generic message if that reason is empty.
if not _npu_available:
    _skip_reason = "NPU device is required"
else:
    _skip_reason = _reason_for_no_mxfp4 or "MXFP4 is not available"

# Module-wide marker: every test in this file is skipped unless an NPU is
# present and MXFP4 is reported as available.
pytestmark = pytest.mark.skipif(
    not _npu_available or not _mxfp4_available,
    reason=_skip_reason,
)
def _uniform_npu(shape: tuple[int, ...], *, dtype: torch.dtype) -> torch.Tensor:
values = torch.linspace(
-1.0,
1.0,
steps=math.prod(shape),
dtype=torch.float32,
device="cpu",
).reshape(shape)
return values.to(device=torch.device("npu"), dtype=dtype).contiguous()
def test_mxfp4_quantize_invokes_dynamic_mx_quant() -> None:
    """Quantize a 64x64 bfloat16 tensor and sanity-check the MXFP4 result.

    The quantizer is built with ``rowwise=True``, ``columnwise=True`` and
    ``with_rht=True``. The test checks the type, logical shape and dtype of
    the result, that all four row-wise/column-wise data and scale buffers are
    populated, and that both data buffers live on the NPU. It does not observe
    the underlying kernel call directly.
    """
    source = _uniform_npu((64, 64), dtype=torch.bfloat16)
    result = MXFP4Quantizer(rowwise=True, columnwise=True, with_rht=True)(source)
    # Wait for the device so that any asynchronous failure surfaces here.
    torch.npu.synchronize()

    assert isinstance(result, MXFP4Tensor)
    assert tuple(result.shape) == tuple(source.shape)
    assert result.dtype == source.dtype

    # Both layouts were requested, so every buffer must have been produced.
    populated_buffers = (
        "_rowwise_data",
        "_rowwise_scale_inv",
        "_columnwise_data",
        "_columnwise_scale_inv",
    )
    for buffer_name in populated_buffers:
        assert getattr(result, buffer_name) is not None

    # The quantized payloads must stay on the device.
    for buffer_name in ("_rowwise_data", "_columnwise_data"):
        assert getattr(result, buffer_name).device.type == "npu"
def test_mxfp4_quantized_matmul_invokes_quant_matmul() -> None:
    """Multiply two MXFP4-quantized tensors and check the output metadata.

    A (64, 64) input and a (96, 64) weight are quantized with one shared
    quantizer (``with_rht=False``). They are multiplied via
    ``MXFP4Tensor.matmul`` with ``TensorUsage.LHS`` / ``TensorUsage.RHS_TRANS``
    and a bfloat16 output dtype. Only the output shape, dtype and device are
    verified, not the numerical values.
    """
    rows, inner, cols = 64, 64, 96
    activations = _uniform_npu((rows, inner), dtype=torch.bfloat16)
    weights = _uniform_npu((cols, inner), dtype=torch.bfloat16)

    quantize = MXFP4Quantizer(rowwise=True, columnwise=True, with_rht=False)
    lhs = quantize(activations)
    rhs = quantize(weights)

    product = lhs.matmul(rhs, TensorUsage.LHS, TensorUsage.RHS_TRANS, torch.bfloat16)
    # Wait for the device so that any asynchronous failure surfaces here.
    torch.npu.synchronize()

    # The expected shape pairs the input's first dim with the weight's first dim.
    expected_shape = (activations.shape[0], weights.shape[0])
    assert tuple(product.shape) == expected_shape
    assert product.dtype == torch.bfloat16
    assert product.device.type == "npu"