import io
import math
from collections.abc import Iterable
from typing import Any, Optional

import pytest
import torch

from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE, TensorUsage
from transformer_engine.pytorch.quantization import (
    is_fp8_available,
    is_fp8_block_scaling_available,
    is_mxfp8_available,
)
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
    Float8Tensor,
    MXFP8Quantizer,
    MXFP8Tensor,
)


# Nominal (high-precision) dtypes and FP8 storage formats swept by the parametrized tests.
_dense_dtypes = [torch.float32, torch.float16, torch.bfloat16]
_dtypes = [torch.float16, torch.bfloat16]
_fp8_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]

# Keyword arguments for torch.testing.assert_close, keyed by FP8 storage format.
_float8_tols = {
    torch.float8_e4m3fn: {"rtol": 0.125, "atol": 0.0675},
    torch.float8_e5m2: {"rtol": 0.25, "atol": 0.125},
}

# The NVIDIA CUDA suite (tests/pytorch/test_quantized_tensor.py) checks MXFP8 with
# per-element tolerances named `_tols`. The current NPU backend does not yet meet
# those strict bounds for MXFP8 dequantize/chunk numerics, so this port keeps its
# own acceptance contract (mean and max absolute error) under a separate name
# instead of overloading the NVIDIA one.
_npu_mxfp8_acceptance_bounds = {
    torch.float8_e4m3fn: {"max_mae": 0.1, "max_abs": 1.0},
    torch.float8_e5m2: {"max_mae": 0.2, "max_abs": 1.0},
}

# Probe each quantization recipe once at import time; the flags and reasons feed
# the skipif markers and the parametrization list below.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)

# Names of the quantization modes usable on this machine, in a fixed order.
_quantization_list = [
    mode
    for mode, usable in (
        ("fp8", fp8_available),
        ("fp8_blockwise", fp8_block_scaling_available),
        ("mxfp8", mxfp8_available),
    )
    if usable
]


def _to_list(x: Iterable | Any) -> list:
    """Convert to list if iterable, otherwise put in singleton list."""
    if isinstance(x, Iterable):
        return list(x)
    return [x]


def _npu_device() -> torch.device:
    """Return the ``torch.device`` handle for the "npu" backend used throughout this file."""
    npu = torch.device("npu")
    return npu


def _rand_npu(shape: int | Iterable[int], *, dtype: torch.dtype) -> torch.Tensor:
    """Draw ``torch.rand`` samples on the CPU and move them to the NPU.

    Sampling happens on the host so that no NPU random kernel is exercised.
    """
    host_sample = torch.rand(shape, dtype=dtype, device="cpu")
    return host_sample.to(device=_npu_device())


def _uniform_npu(shape: int | Iterable[int], *, dtype: torch.dtype) -> torch.Tensor:
    """Return an NPU tensor of samples rescaled from ``_rand_npu`` via ``2 * x - 1``.

    The underlying samples are generated on the CPU, so no NPU random kernel is used.
    """
    unit_sample = _rand_npu(shape, dtype=dtype)
    return 2 * unit_sample - 1


def _npu_float8_raw_shape(shape: Iterable[int]) -> torch.Size:
    """Map a logical shape to the 2-D raw storage shape expected for NPU Float8 data.

    A scalar shape becomes ``(1, 1)``, a 2-D shape is returned unchanged, and any
    other rank is flattened to ``(product of leading dims, last dim)``.
    """
    dims = tuple(shape)
    rank = len(dims)
    if rank == 0:
        return torch.Size((1, 1))
    if rank == 2:
        return torch.Size(dims)
    # Rank 1 has no leading dims, so the product over them is 1.
    *leading, last = dims
    return torch.Size((math.prod(leading), last))


def _logical_size(tensor: torch.Tensor) -> torch.Size:
    """Return the logical (NVIDIA-facing) size of ``tensor``.

    Quantized tensors that carry a ``_shape`` attribute report their
    ``origin_shape``; every other tensor reports its ordinary ``size()``.
    """
    # NOTE(review): the guard looks for `_shape` while the value is read from
    # `origin_shape` -- presumably the NPU tensor classes always set both; confirm.
    if not isinstance(tensor, QuantizedTensor) or not hasattr(tensor, "_shape"):
        return tensor.size()
    return torch.Size(tensor.origin_shape)


def _assert_float8_raw_storage(tensor: Float8Tensor) -> None:
    """Assert that the raw Float8 buffers have the 2-D storage shape implied by the logical size.

    Both the private ``_data`` buffer and the data returned by
    ``get_data(TensorUsage.LHS)`` are checked, without consulting the tensor's
    public ``size()``.
    """
    raw_shape = _npu_float8_raw_shape(_logical_size(tensor))
    assert tensor._data.size() == raw_shape
    lhs_data, _unused = tensor.get_data(TensorUsage.LHS)
    assert lhs_data.size() == raw_shape


def _seed() -> None:
    """Seed torch's global RNG, and the NPU RNG when ``torch.npu`` exists, with a fixed value."""
    fixed_seed = 1234
    torch.manual_seed(fixed_seed)
    if not hasattr(torch, "npu"):
        return
    torch.npu.manual_seed(fixed_seed)


def _make_float8_quantizer(
    *,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    scale: float = 1.0,
    device: torch.device | None = None,
    rowwise: bool = True,
    columnwise: bool = True,
) -> Float8Quantizer:
    """Build a ``Float8Quantizer`` with a one-element scale buffer and a zeroed amax buffer.

    Args:
        fp8_dtype: FP8 storage format handed to the quantizer.
        scale: Value written into the float32 scale buffer.
        device: Device for the scale/amax buffers; the NPU device when ``None``.
        rowwise: Forwarded to the quantizer's ``rowwise`` argument.
        columnwise: Forwarded to the quantizer's ``columnwise`` argument.
    """
    target = _npu_device() if device is None else device
    one_element = (1,)
    scale_buffer = torch.full(one_element, scale, dtype=torch.float32, device=target)
    amax_buffer = torch.zeros(one_element, dtype=torch.float32, device=target)
    return Float8Quantizer(
        scale=scale_buffer,
        amax=amax_buffer,
        fp8_dtype=fp8_dtype,
        rowwise=rowwise,
        columnwise=columnwise,
    )


def _to_float8(
    tensor: torch.Tensor,
    *,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    scale: float = 1.0,
) -> Float8Tensor:
    """Move ``tensor`` to the NPU and quantize it with a freshly built ``Float8Quantizer``.

    Args:
        tensor: High-precision input; it is copied to the NPU device first.
        fp8_dtype: FP8 storage format for the quantizer.
        scale: Scale value for the quantizer.
    """
    on_npu = tensor.to(device=_npu_device())
    quantize = _make_float8_quantizer(
        fp8_dtype=fp8_dtype,
        scale=scale,
        device=on_npu.device,
    )
    return quantize(on_npu)


def _to_float8_current_scaling(
    tensor: torch.Tensor,
    *,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    return_transpose: bool = False,
    force_pow_2_scales: bool = False,
    amax_epsilon: float = 0.0,
) -> Float8Tensor:
    """Move ``tensor`` to the NPU and quantize it with ``Float8CurrentScalingQuantizer``.

    Args:
        tensor: High-precision input; it is copied to the NPU device first.
        fp8_dtype: FP8 storage format for the quantizer.
        return_transpose: Passed as the quantizer's ``columnwise`` usage flag
            (``rowwise`` usage is always enabled).
        force_pow_2_scales: Forwarded to the quantizer constructor.
        amax_epsilon: Forwarded to the quantizer constructor.
    """
    on_npu = tensor.to(device=_npu_device())
    current_scaling = Float8CurrentScalingQuantizer(
        fp8_dtype=fp8_dtype,
        device=on_npu.device,
        force_pow_2_scales=force_pow_2_scales,
        amax_epsilon=amax_epsilon,
    )
    current_scaling.set_usage(rowwise=True, columnwise=return_transpose)
    return current_scaling(on_npu)


def _to_float8_blockwise(
    tensor: torch.Tensor,
    *,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> Float8BlockwiseQTensor:
    """Move ``tensor`` to the NPU and quantize it with ``Float8BlockQuantizer``.

    The quantizer is configured with rowwise and columnwise usage, power-of-two
    scales, zero amax epsilon and ``block_scaling_dim=1``.
    """
    on_npu = tensor.to(device=_npu_device())
    block_quantizer = Float8BlockQuantizer(
        fp8_dtype=fp8_dtype,
        rowwise=True,
        columnwise=True,
        force_pow_2_scales=True,
        amax_epsilon=0.0,
        block_scaling_dim=1,
    )
    return block_quantizer(on_npu)


@torch.no_grad()
def make_reference_and_test_tensors(
    shape: int | Iterable[int],
    quantization: Optional[str] = None,
    ref_dtype: torch.dtype = torch.float64,
    ref_device: torch.device | str = "cpu",
    test_dtype: torch.dtype = torch.bfloat16,
    test_device: torch.device | str = _npu_device(),
    requires_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a reference tensor and a test tensor that hold the same values.

    Follows the shape of the helper in NVIDIA's tests/pytorch/test_quantized_tensor.py,
    restricted to the quantization modes exercised by this NPU port.

    Args:
        shape: Shape of both tensors.
        quantization: ``None`` for a plain tensor, or one of ``"fp8"``,
            ``"fp8_blockwise"``, ``"mxfp8"``.
        ref_dtype: Dtype of the reference tensor.
        ref_device: Device of the reference tensor.
        test_dtype: Dtype the random values are cast to for the test tensor.
        test_device: Device of the test tensor.
        requires_grad: Applied to both returned tensors.

    Returns:
        ``(ref, test)``; ``ref`` is overwritten with the values ``test`` actually
        holds (after the dtype cast and, if quantized, dequantization).

    Raises:
        ValueError: If ``quantization`` names an unsupported scheme.
    """

    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
    test = ref.to(device=test_device, dtype=test_dtype)

    if quantization is None:
        # `.to` can hand back the same storage; make sure the two tensors are distinct.
        shares_storage = test.data_ptr() == ref.data_ptr()
        if shares_storage:
            test = test.clone()
    elif quantization == "fp8":
        test = _to_float8(test, fp8_dtype=torch.float8_e4m3fn, scale=1.0)
    elif quantization == "fp8_blockwise":
        test = _to_float8_blockwise(test, fp8_dtype=torch.float8_e4m3fn)
    elif quantization == "mxfp8":
        mxfp8_quantizer = MXFP8Quantizer(fp8_dtype=torch.float8_e4m3fn)
        test = mxfp8_quantizer(test)
    else:
        raise ValueError(f"Unsupported quantization scheme ({quantization})")

    # Write back the values the test tensor really represents, so both sides agree exactly.
    held_values = test.dequantize() if isinstance(test, QuantizedTensor) else test
    ref.copy_(held_values.to(dtype=ref.dtype, device="cpu"))

    ref.requires_grad_(requires_grad)
    test.requires_grad_(requires_grad)
    return ref, test


def _assert_npu_mxfp8_acceptance(
    actual: torch.Tensor,
    expected: torch.Tensor,
    *,
    max_mae: float,
    max_abs: float,
) -> None:
    """Assert that ``actual`` matches ``expected`` within aggregate error bounds.

    Both tensors are detached and compared as float32 on the CPU.

    Args:
        actual: Tensor under test.
        expected: Reference tensor; must have the same shape as ``actual``.
        max_mae: Upper bound on the mean absolute error.
        max_abs: Upper bound on the largest absolute error.

    Raises:
        AssertionError: If the shapes differ, or if either error statistic
            exceeds its bound (a NaN statistic also fails the comparison).
    """
    actual = actual.detach().to(dtype=torch.float32, device="cpu")
    expected = expected.detach().to(dtype=torch.float32, device="cpu")
    # The subtraction below broadcasts, so without this check a wrongly shaped
    # (but broadcastable) result could satisfy the bounds.
    assert (
        actual.shape == expected.shape
    ), f"shape mismatch: got {tuple(actual.shape)}, expected {tuple(expected.shape)}"
    diff = (actual - expected).abs()
    mae = diff.mean().item()
    max_err = diff.max().item()
    assert mae <= max_mae, f"mae too large: got {mae:.6f}, expected <= {max_mae:.6f}"
    assert max_err <= max_abs, f"max_abs too large: got {max_err:.6f}, expected <= {max_abs:.6f}"


def _decode_mxfp8_data_for_reference(data: torch.Tensor, fp8_dtype: torch.dtype) -> torch.Tensor:
    """Decode MXFP8 payload to float32 on CPU."""
    if data.dtype == fp8_dtype:
        decoded = data
    else:
        decoded = data.contiguous().view(dtype=fp8_dtype)
    return decoded.to(dtype=torch.float32, device="cpu")


def _columnwise_scalar_reference_from_logical_scale(
    data: torch.Tensor,
    logical_scale: torch.Tensor,
    *,
    fp8_dtype: torch.dtype,
) -> torch.Tensor:
    """Independent CPU scalar reference for 1x32 block dequantize.

    The payload is viewed as a 2-D ``(m_dim, k_dim)`` matrix (all leading
    dimensions flattened into ``m_dim``). Each run of
    ``MXFP8_BLOCK_SCALING_SIZE`` consecutive elements along the last dimension
    shares one scale; a final partial block is handled by clamping its end.

    Args:
        data: MXFP8 payload, decoded via ``_decode_mxfp8_data_for_reference``.
        logical_scale: Scale factors of shape ``(m_dim, ceil(k_dim / block))``.
        fp8_dtype: FP8 format used to decode ``data``.

    Returns:
        A float32 CPU tensor with the same shape as ``data``.

    Raises:
        AssertionError: If ``logical_scale`` does not have the expected shape.
    """
    q = _decode_mxfp8_data_for_reference(data, fp8_dtype)
    scale = logical_scale.detach().to(dtype=torch.float32, device="cpu")
    m_dim = math.prod(q.shape[:-1]) if q.ndim > 1 else 1
    k_dim = q.shape[-1] if q.ndim > 0 else 1
    q_2d = q.reshape(m_dim, k_dim)
    out = torch.empty((m_dim, k_dim), dtype=torch.float32, device="cpu")

    k_blocks = math.ceil(k_dim / MXFP8_BLOCK_SCALING_SIZE)
    assert tuple(scale.shape) == (m_dim, k_blocks)
    # Deliberately scalar (element-by-element) so the reference shares no
    # vectorized code path with the implementation under test.
    for m_idx in range(m_dim):
        for k_block in range(k_blocks):
            block_scale = float(scale[m_idx, k_block])
            start = k_block * MXFP8_BLOCK_SCALING_SIZE
            end = min(start + MXFP8_BLOCK_SCALING_SIZE, k_dim)
            for k_idx in range(start, end):
                out[m_idx, k_idx] = float(q_2d[m_idx, k_idx]) * block_scale
    # Pass the shape as one object: unpacking an empty (0-dim) shape would call
    # reshape() with no arguments, which raises.
    return out.reshape(q.shape)


def _rowwise_scalar_reference_from_logical_scale(
    data: torch.Tensor,
    logical_scale: torch.Tensor,
    *,
    fp8_dtype: torch.dtype,
) -> torch.Tensor:
    """Independent CPU scalar reference for 32x1 block dequantize.

    The payload is viewed as a 2-D ``(m_dim, k_dim)`` matrix (all leading
    dimensions flattened into ``m_dim``). Each run of
    ``MXFP8_BLOCK_SCALING_SIZE`` consecutive rows shares one scale per column; a
    final partial block of rows is handled by clamping its end.

    Args:
        data: MXFP8 payload, decoded via ``_decode_mxfp8_data_for_reference``.
        logical_scale: Scale factors of shape ``(ceil(m_dim / block), k_dim)``.
        fp8_dtype: FP8 format used to decode ``data``.

    Returns:
        A float32 CPU tensor with the same shape as ``data``.

    Raises:
        AssertionError: If ``logical_scale`` does not have the expected shape.
    """
    q = _decode_mxfp8_data_for_reference(data, fp8_dtype)
    scale = logical_scale.detach().to(dtype=torch.float32, device="cpu")
    m_dim = math.prod(q.shape[:-1]) if q.ndim > 1 else 1
    k_dim = q.shape[-1] if q.ndim > 0 else 1
    q_2d = q.reshape(m_dim, k_dim)
    out = torch.empty((m_dim, k_dim), dtype=torch.float32, device="cpu")

    m_blocks = math.ceil(m_dim / MXFP8_BLOCK_SCALING_SIZE)
    assert tuple(scale.shape) == (m_blocks, k_dim)
    # Deliberately scalar (element-by-element) so the reference shares no
    # vectorized code path with the implementation under test.
    for m_block in range(m_blocks):
        start = m_block * MXFP8_BLOCK_SCALING_SIZE
        end = min(start + MXFP8_BLOCK_SCALING_SIZE, m_dim)
        for k_idx in range(k_dim):
            block_scale = float(scale[m_block, k_idx])
            for m_idx in range(start, end):
                out[m_idx, k_idx] = float(q_2d[m_idx, k_idx]) * block_scale
    # Pass the shape as one object: unpacking an empty (0-dim) shape would call
    # reshape() with no arguments, which raises.
    return out.reshape(q.shape)


def _make_manual_mxfp8_tensor(
    shape: Iterable[int],
    *,
    dtype: torch.dtype,
    fp8_dtype: torch.dtype,
    rowwise_data: Optional[torch.Tensor],
    rowwise_scale_inv: Optional[torch.Tensor],
    columnwise_data: Optional[torch.Tensor],
    columnwise_scale_inv: Optional[torch.Tensor],
) -> MXFP8Tensor:
    """Wrap caller-supplied raw buffers in an ``MXFP8Tensor`` without running a quantizer.

    The tensor is created on the NPU device with no attached quantizer and with
    ``with_gemm_swizzled_scales=False``; either layout's buffers may be ``None``.
    """
    constructor_args = {
        "shape": tuple(shape),
        "dtype": dtype,
        "rowwise_data": rowwise_data,
        "rowwise_scale_inv": rowwise_scale_inv,
        "columnwise_data": columnwise_data,
        "columnwise_scale_inv": columnwise_scale_inv,
        "fp8_dtype": fp8_dtype,
        "quantizer": None,
        "with_gemm_swizzled_scales": False,
        "device": _npu_device(),
    }
    return MXFP8Tensor(**constructor_args)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
    """Tests for ``Float8Tensor`` on the NPU device.

    Covers construction, quantize/dequantize round trips, the no-op quantize
    flag, arithmetic, chunking, in-place updates, serialization and ``.data``
    assignment. The whole class is skipped when FP8 is unavailable.
    """

    @staticmethod
    def setup_class() -> None:
        """Seed the RNGs once before the tests of this class run."""
        _seed()

    @staticmethod
    def _test_quantize_dequantize(
        *,
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
        dims: int | Iterable[int] = 23,
    ) -> None:
        """Quantize random CPU data, dequantize it, and compare against the input.

        Shared body for the parametrized round-trip tests below. Also checks
        that the comparison rejects the negated input.
        """
        # Reference values are 2 * rand - 1, generated on the CPU.
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
        x_fp8 = _to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_dequantized = x_fp8.dequantize(dtype=dtype).to(device="cpu")

        torch.testing.assert_close(x_dequantized, x_ref, **_float8_tols[fp8_dtype])
        # Negative control: the tolerances must be tight enough to reject -x_ref.
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_dequantized, -x_ref, **_float8_tols[fp8_dtype])

    def test_constructor(
        self,
        dims: int | Iterable[int] = 1,
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        scale_inv: float = 0.375,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Adapted from NVIDIA TestFloat8Tensor.test_constructor."""
        dims = _to_list(dims)
        # NOTE(review): the quantizer below is built with scale=scale_inv rather
        # than its reciprocal; no assertion in this test depends on it -- confirm intent.
        tensor = Float8Tensor(
            shape=dims,
            dtype=dtype,
            data=torch.zeros(dims, device=_npu_device(), dtype=fp8_dtype),
            fp8_scale_inv=torch.full((1,), scale_inv, dtype=torch.float32, device=_npu_device()),
            fp8_dtype=fp8_dtype,
            quantizer=_make_float8_quantizer(fp8_dtype=fp8_dtype, scale=scale_inv),
            device=_npu_device(),
        )
        assert list(tensor.size()) == dims
        assert tensor.dtype == dtype
        # NOTE(review): presumably the NPU port reports `is_cuda` as True for
        # parity with the NVIDIA-facing API -- confirm.
        assert tensor.is_cuda

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dense_dtypes)
    def test_quantize_dequantize_dtypes(
        self,
        fp8_dtype: torch.dtype,
        dtype: torch.dtype,
    ) -> None:
        """Round-trip check across FP8 formats and nominal dtypes."""
        self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)

    @pytest.mark.parametrize("scale", [0.375, 1.0, 3.5])
    def test_quantize_dequantize_scales(self, scale: float) -> None:
        """Round-trip check across quantizer scale values."""
        self._test_quantize_dequantize(scale=scale)

    @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
    def test_quantize_dequantize_dims(self, dims: int | Iterable[int]) -> None:
        """Round-trip check across tensor shapes, including a scalar shape."""
        self._test_quantize_dequantize(dims=dims)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dense_dtypes)
    @pytest.mark.parametrize("noop", [True, False])
    def test_quantize_dequantize_noop(
        self,
        fp8_dtype: torch.dtype,
        dtype: torch.dtype,
        noop: bool,
    ) -> None:
        """Adapted from NVIDIA TestFloat8Tensor.test_quantize_dequantize_noop."""
        # The flag tensor holds 1.0 when the in-place quantize should be skipped.
        noop_tensor = torch.zeros(1, dtype=torch.float32, device=_npu_device())
        if noop:
            noop_tensor = torch.ones(1, dtype=torch.float32, device=_npu_device())
        dims = 23
        scale = 3.5

        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
        x_fp8 = _to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        # New source values (twice the original) for the second, in-place quantize.
        x_ref_noop_test = 2 * x_ref.to(device=_npu_device())
        x_fp8_orig = x_fp8.clone()

        x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor)
        if noop_tensor.item() == 1.0:
            # Skipped: the tensor must be exactly what it was before the call.
            torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0)
        else:
            # Executed: the tensor must now represent the new source values.
            torch.testing.assert_close(
                x_fp8.dequantize(dtype=dtype),
                x_ref_noop_test,
                **_float8_tols[fp8_dtype],
            )

    def test_basic_ops(
        self,
        dims: int | Iterable[int] = 23,
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Adapted from NVIDIA TestFloat8Tensor.test_basic_ops."""
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = _to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        y_fp8 = _to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)
        # Replace the references with the dequantized values, so the checks
        # below compare operations on identical underlying numbers.
        x_ref = x_fp8.dequantize(dtype=dtype)
        y_ref = y_fp8.dequantize(dtype=dtype)

        # Unary ops are compared with zero tolerance.
        torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)

        # Binary ops (quantized/quantized and mixed operands) and sin use the FP8 tolerances.
        tols = _float8_tols[fp8_dtype]
        torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
        torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
        torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)

        # Negative control: a sum must not compare equal to a difference.
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)

    @pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]])
    def test_chunk_op(
        self,
        dims: int | Iterable[int],
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        scale: float = 1.0,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Adapted from NVIDIA TestFloat8Tensor.test_chunk_op."""
        dims = _to_list(dims)
        x_ref = torch.randn(dims, dtype=dtype, device="cpu")
        x_fp8 = _to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        chunks = x_fp8.chunk(2, dim=0)
        # Reference: chunk the dequantized tensor on the CPU.
        ref_chunks = x_fp8.dequantize(dtype=dtype).cpu().chunk(2, dim=0)

        assert len(chunks) == len(ref_chunks)
        for chunk, ref_chunk in zip(chunks, ref_chunks):
            # Each chunk must stay quantized, keep the logical shape, and match exactly.
            assert isinstance(chunk, Float8Tensor)
            assert _logical_size(chunk) == ref_chunk.shape
            torch.testing.assert_close(
                chunk.dequantize(dtype=dtype).cpu(),
                ref_chunk,
                rtol=0,
                atol=0,
            )

    def test_inplace_ops(
        self,
        dims: int | Iterable[int] = 23,
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Adapted from NVIDIA TestFloat8Tensor.test_inplace_ops."""
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = _to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        y_fp8 = _to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_ref = x_fp8.dequantize(dtype=dtype)
        y_ref = y_fp8.dequantize(dtype=dtype)

        # After each in-place op the reference is re-read from the quantized
        # tensor, so every step is compared from the same starting values.
        tols = _float8_tols[fp8_dtype]
        x_fp8 += y_ref
        x_ref += y_ref
        torch.testing.assert_close(x_fp8.dequantize(dtype=dtype), x_ref, **tols)
        x_ref = x_fp8.dequantize(dtype=dtype)
        x_fp8 -= y_fp8
        x_ref -= y_fp8.dequantize(dtype=dtype)
        torch.testing.assert_close(x_fp8.dequantize(dtype=dtype), x_ref, **tols)
        x_ref = x_fp8.dequantize(dtype=dtype)
        x_fp8 *= 2
        x_ref *= 2
        torch.testing.assert_close(x_fp8.dequantize(dtype=dtype), x_ref, **tols)

        # Negative control: a large offset on the reference must be detected.
        x_ref += 123
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8.dequantize(dtype=dtype), x_ref, **tols)

    def test_serialization(
        self,
        dims: int | Iterable[int] = (2, 3, 5),
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        scale: float = 0.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Adapted from NVIDIA TestFloat8Tensor.test_serialization."""
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = _to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_ref = x_fp8.dequantize(dtype=dtype).cpu()

        # Serialize to an in-memory byte string.
        byte_stream = io.BytesIO()
        torch.save(x_fp8, byte_stream)
        x_bytes = byte_stream.getvalue()

        # Wipe and drop the original so the loaded tensor cannot depend on it.
        x_fp8._data.zero_()
        x_fp8._scale_inv.zero_()
        del x_fp8, byte_stream

        x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False)
        del x_bytes
        torch.testing.assert_close(x_fp8.dequantize(dtype=dtype).cpu(), x_ref, rtol=0, atol=0)

        # Negative control: zeroing the loaded buffers must change the values.
        x_fp8._data.zero_()
        x_fp8._scale_inv.zero_()
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8.dequantize(dtype=dtype).cpu(), x_ref, rtol=0, atol=0)

    def test_set_data(self) -> None:
        """Adapted from NVIDIA TestFloat8Tensor.test_set_data."""
        # Case 1: freshly quantized 1-D tensor.
        x0 = torch.zeros(4, dtype=torch.float32)
        x = _to_float8(x0)
        assert isinstance(x, Float8Tensor)
        assert x0.size() == _logical_size(x)
        _assert_float8_raw_storage(x)
        assert x.dtype == torch.float32
        assert x.is_cuda and x._data.is_npu
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert _logical_size(x) == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device

        # Case 2: assign a plain tensor with a different shape and dtype to `.data`.
        x0 = torch.zeros((3, 2), dtype=torch.float16, device=x.device)
        x.data = x0
        assert isinstance(x, Float8Tensor)
        assert x0.size() == _logical_size(x)
        _assert_float8_raw_storage(x)
        assert x0.dtype == x.dtype
        assert x0.device == x.device == x._data.device
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert _logical_size(x) == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device

        # Case 3: assign another Float8Tensor; the raw buffers are expected to
        # be the very same objects, not copies.
        x = _to_float8(torch.ones((4, 3, 1), dtype=torch.float32))
        x0 = _to_float8(torch.zeros((4, 3, 1), dtype=torch.float32))
        x.data = x0
        assert isinstance(x, Float8Tensor)
        assert _logical_size(x0) == _logical_size(x)
        _assert_float8_raw_storage(x)
        assert x0.dtype == x.dtype
        assert x0.device == x.device == x._data.device
        assert x0._data is x._data
        assert x0._scale_inv is x._scale_inv
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert _logical_size(x) == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestCurrentScalingFloat8Tensor:
    """Tests for ``Float8Tensor`` values produced by ``Float8CurrentScalingQuantizer``."""

    @staticmethod
    def setup_class() -> None:
        """Seed the RNGs once before the tests of this class run."""
        _seed()

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize(
        "dims",
        [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3], [128, 128], [611, 782]],
    )
    @pytest.mark.parametrize("return_transpose", [True, False], ids=str)
    @pytest.mark.parametrize("force_pow_2_scales", [True, False], ids=str)
    @pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str)
    def test_quantize(
        self,
        fp8_dtype: torch.dtype,
        dtype: torch.dtype,
        dims: int | Iterable[int],
        return_transpose: bool,
        force_pow_2_scales: bool,
        amax_epsilon: float,
    ) -> None:
        """Adapted from NVIDIA TestCurrentScalingFloat8Tensor.test_quantize.

        Checks the metadata of the quantized tensor and that dequantizing it
        recovers the input within the FP8 tolerances.
        """
        source = _uniform_npu(_to_list(dims), dtype=dtype)
        quantized = _to_float8_current_scaling(
            source,
            fp8_dtype=fp8_dtype,
            return_transpose=return_transpose,
            force_pow_2_scales=force_pow_2_scales,
            amax_epsilon=amax_epsilon,
        )

        # Structural checks on the result.
        assert isinstance(quantized, Float8Tensor)
        assert quantized._data is not None
        assert quantized._data.dtype == fp8_dtype
        assert quantized._scale_inv.dtype == torch.float32
        assert quantized._quantizer is not None
        expected_usages = {"rowwise": True, "columnwise": return_transpose}
        assert quantized._quantizer.get_usages() == expected_usages

        # Numerical check.
        tols = _float8_tols[fp8_dtype]
        torch.testing.assert_close(quantized.dequantize(dtype=dtype), source, **tols)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", [torch.bfloat16])
    @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
    def test_quantize_dequantize(
        self,
        fp8_dtype: torch.dtype,
        dtype: torch.dtype,
        dims: int | Iterable[int],
    ) -> None:
        """Adapted from NVIDIA TestCurrentScalingFloat8Tensor.test_quantize_dequantize.

        Round-trips random data and also verifies that a shifted reference is rejected.
        """
        source = _uniform_npu(_to_list(dims), dtype=dtype)
        quantized = _to_float8_current_scaling(source, fp8_dtype=fp8_dtype)
        round_trip = quantized.dequantize(dtype=dtype)

        tols = _float8_tols[fp8_dtype]
        torch.testing.assert_close(round_trip, source, **tols)
        # Negative control: an offset of 1.0 must fall outside the tolerances.
        with pytest.raises(AssertionError):
            torch.testing.assert_close(round_trip, source + 1.0, **tols)


class TestQuantizedTensor:
    """Tests that apply to every quantized tensor type available on this machine."""

    @staticmethod
    def setup_class() -> None:
        """Seed the RNGs once before the tests of this class run."""
        _seed()

    @pytest.mark.parametrize("op", ("clone", "view", "reshape", "contiguous"))
    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("shape", [(128, 128), (1024, 2048)])
    def test_identity_op(
        self,
        *,
        op: str,
        quantization: str,
        shape: Iterable[int],
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        """Adapted from NVIDIA TestQuantizedTensor.test_identity_op.

        Applies a shape-preserving op and checks that values are unchanged and
        that the incoming gradient reaches the input unmodified.
        """

        x_ref, x_test = make_reference_and_test_tensors(
            shape=shape,
            quantization=quantization,
            test_dtype=dtype,
        )
        grad_out_test = _rand_npu(shape, dtype=dtype)
        grad_out_ref = grad_out_test.to(dtype=torch.float64, device="cpu")

        # Apply the same op to the reference and to the tensor under test.
        if op in ("view", "reshape"):
            y_ref = getattr(x_ref, op)(shape)
            y_test = getattr(x_test, op)(shape)
        elif op == "clone":
            y_ref = x_ref.clone()
            y_test = x_test.clone()
        else:
            y_ref = x_ref.contiguous()
            y_test = x_test.contiguous()

        # Blockwise results whose public size differs from the gradient's size
        # are only compared by value; the backward pass is not run for them.
        size_differs = (
            isinstance(y_test, Float8BlockwiseQTensor) and y_test.size() != grad_out_test.size()
        )
        if size_differs:
            values = y_test.dequantize(dtype=dtype)
            values = values.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(values, y_ref, rtol=0, atol=0)
            return

        y_test.backward(grad_out_test)
        assert x_test.grad is not None

        forward = y_test.dequantize(dtype=dtype) if isinstance(y_test, QuantizedTensor) else y_test
        forward = forward.to(dtype=torch.float64, device="cpu")
        grad_in = x_test.grad.to(dtype=torch.float64, device="cpu")

        torch.testing.assert_close(forward, y_ref, rtol=0, atol=0)
        torch.testing.assert_close(grad_in, grad_out_ref, rtol=0, atol=0)

    @pytest.mark.parametrize("quantization", _quantization_list)
    @pytest.mark.parametrize("dim", [0, 1])
    @pytest.mark.parametrize("shape", [(128, 128), (1024, 2048)])
    def test_chunk(
        self,
        *,
        quantization: str,
        dim: int,
        shape: Iterable[int],
        chunks: int = 2,
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        """Adapted from NVIDIA TestQuantizedTensor.test_chunk.

        Chunks the reference and the test tensor the same way and requires each
        pair of pieces to agree in logical size and exactly in value.
        """

        x_ref, x_test = make_reference_and_test_tensors(
            shape=shape,
            quantization=quantization,
            test_dtype=dtype,
            requires_grad=False,
        )

        ref_pieces = torch.chunk(x_ref, chunks, dim=dim)
        test_pieces = torch.chunk(x_test, chunks, dim=dim)

        for ref_piece, test_piece in zip(ref_pieces, test_pieces):
            assert ref_piece.size() == _logical_size(test_piece)

            if quantization == "fp8":
                # FP8 chunks must remain Float8Tensor instances.
                assert isinstance(test_piece, Float8Tensor)
                dense_piece = test_piece.dequantize(dtype=dtype)
            elif isinstance(test_piece, QuantizedTensor):
                dense_piece = test_piece.dequantize(dtype=dtype)
            else:
                dense_piece = test_piece

            dense_piece = dense_piece.to(dtype=torch.float64, device="cpu")
            torch.testing.assert_close(dense_piece, ref_piece, rtol=0, atol=0)


@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
class TestMXFP8Tensor:
    """Tests for ``MXFP8Tensor`` dequantization on the NPU device."""

    @staticmethod
    def setup_class() -> None:
        """Seed the RNGs once before the tests of this class run."""
        _seed()

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [1024, 2048]])
    def test_mxfp8_dequantize_columnwise_only(
        self,
        fp8_dtype: torch.dtype,
        dtype: torch.dtype,
        dims,
    ) -> None:
        """Port of NVIDIA TestMXFP8Tensor.test_mxfp8_dequantize_columnwise_only.

        Quantizes with both layouts, drops the rowwise data, and checks that
        dequantization still meets the NPU acceptance bounds.
        """

        source = _uniform_npu(_to_list(dims), dtype=dtype)
        bounds = _npu_mxfp8_acceptance_bounds[fp8_dtype]

        both_layouts = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
        packed = both_layouts(source)

        # Dequantize while both layouts are still present.
        with_both = packed.dequantize(dtype=dtype)
        _assert_npu_mxfp8_acceptance(with_both, source, **bounds)

        # Keep only the columnwise layout and dequantize again.
        packed.update_usage(rowwise_usage=False, columnwise_usage=True)
        assert packed._rowwise_data is None
        assert packed._columnwise_data is not None

        columnwise_only = packed.dequantize(dtype=dtype)
        _assert_npu_mxfp8_acceptance(columnwise_only, source, **bounds)

        # Negative control: the negated input must violate the bounds.
        with pytest.raises(AssertionError):
            _assert_npu_mxfp8_acceptance(columnwise_only, -source, **bounds)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dims", [[128, 128], [256, 256], [1024, 2048]])
    def test_mxfp8_dequantize_columnwise_only_quantized_separately(
        self,
        fp8_dtype: torch.dtype,
        dims,
    ) -> None:
        """Port of NVIDIA TestMXFP8Tensor.test_mxfp8_dequantize_columnwise_only_quantized_separately.

        Quantizes with the columnwise layout only and checks the dequantized
        result against the NPU acceptance bounds.
        """

        dtype = torch.bfloat16
        source = _uniform_npu(_to_list(dims), dtype=dtype)
        bounds = _npu_mxfp8_acceptance_bounds[fp8_dtype]

        columnwise_quantizer = MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=False, columnwise=True)
        packed = columnwise_quantizer(source)
        assert packed._rowwise_data is None
        assert packed._columnwise_data is not None

        restored = packed.dequantize(dtype=dtype)
        _assert_npu_mxfp8_acceptance(restored, source, **bounds)

        # Negative control: the negated input must violate the bounds.
        with pytest.raises(AssertionError):
            _assert_npu_mxfp8_acceptance(restored, -source, **bounds)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    def test_mxfp8_columnwise_manual_tail_matches_cpu_scalar_reference(
        self,
        fp8_dtype: torch.dtype,
    ) -> None:
        """Validate non-32-divisible columnwise dequantize against independent CPU scalar math."""

        shape = (2, 17, 65)
        dtype = torch.float32
        device = _npu_device()
        rows = math.prod(shape[:-1])
        cols = shape[-1]
        col_blocks = math.ceil(cols / MXFP8_BLOCK_SCALING_SIZE)

        # Deterministic payload: a ramp over [-1, 1] cast to the FP8 format.
        ramp = torch.linspace(
            -1.0,
            1.0,
            steps=math.prod(shape),
            dtype=torch.float32,
            device=device,
        )
        columnwise_data = ramp.reshape(shape).to(dtype=fp8_dtype)
        # One distinct scale per (row, block-of-columns) entry.
        scale_index = torch.arange(rows * col_blocks, dtype=torch.float32, device=device)
        logical_scale = scale_index.reshape(rows, col_blocks).div(32.0).add(0.5)

        manual = _make_manual_mxfp8_tensor(
            shape,
            dtype=dtype,
            fp8_dtype=fp8_dtype,
            rowwise_data=None,
            rowwise_scale_inv=None,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=logical_scale,
        )

        got = manual.dequantize(dtype=torch.float32).to(device="cpu")
        want = _columnwise_scalar_reference_from_logical_scale(
            columnwise_data,
            logical_scale,
            fp8_dtype=fp8_dtype,
        )
        torch.testing.assert_close(got, want, rtol=0, atol=0)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    def test_mxfp8_rowwise_manual_tail_matches_cpu_scalar_reference(
        self,
        fp8_dtype: torch.dtype,
    ) -> None:
        """Validate non-32-divisible rowwise dequantize against independent CPU scalar math."""

        shape = (2, 17, 65)
        dtype = torch.float32
        device = _npu_device()
        rows = math.prod(shape[:-1])
        cols = shape[-1]
        row_blocks = math.ceil(rows / MXFP8_BLOCK_SCALING_SIZE)

        # Deterministic payload: a ramp over [-1, 1] cast to the FP8 format.
        ramp = torch.linspace(
            -1.0,
            1.0,
            steps=math.prod(shape),
            dtype=torch.float32,
            device=device,
        )
        rowwise_data = ramp.reshape(shape).to(dtype=fp8_dtype)
        # One distinct scale per (block-of-rows, column) entry.
        scale_index = torch.arange(row_blocks * cols, dtype=torch.float32, device=device)
        logical_scale = scale_index.reshape(row_blocks, cols).div(64.0).add(0.25)

        manual = _make_manual_mxfp8_tensor(
            shape,
            dtype=dtype,
            fp8_dtype=fp8_dtype,
            rowwise_data=rowwise_data,
            rowwise_scale_inv=logical_scale,
            columnwise_data=None,
            columnwise_scale_inv=None,
        )

        got = manual.dequantize(dtype=torch.float32).to(device="cpu")
        want = _rowwise_scalar_reference_from_logical_scale(
            rowwise_data,
            logical_scale,
            fp8_dtype=fp8_dtype,
        )
        torch.testing.assert_close(got, want, rtol=0, atol=0)