# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2026, Huawei Technologies Co., Ltd. All rights reserved.
#
# See LICENSE for license information.

"""Module-level sanity coverage adapted from NVIDIA's PyTorch sanity tests.

These tests intentionally exercise user-facing PyTorch modules instead of
``transformer_engine_torch``/``tex`` single operators.  The NPU port keeps a few
backend-specific layout choices, so this file validates the high-level training
path: module construction, autocast recipe wiring, forward/backward execution,
gradient materialization, zero-token handling, and quantized parameter init.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional

import pytest
import torch

from transformer_engine.common import recipe
from transformer_engine.pytorch import (
    GroupedLinear,
    Linear,
    autocast,
    quantized_model_init,
)
from transformer_engine.pytorch.quantization import (
    FP8GlobalStateManager,
    is_fp8_available,
    is_mxfp8_available,
)
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import init_method_normal

from utils import npu_available


@dataclass(frozen=True)
class ModelConfig:
    """Immutable shape parameters shared by the sanity tests in this file."""

    # First dimension of the 3-D inputs built by ``_run_common_module_sanity``.
    max_seqlen_q: int
    # Second dimension of those 3-D inputs.
    batch_size: int
    # Last input dimension; also the feature size passed to the modules.
    hidden_size: int
    # In the visible code this only scales the init std-dev in ``_make_linear``.
    num_layers: int = 2


# Named model shapes.  "small" satisfies the multiple-of-16 check in
# ``_is_fp8_supported``; "weird" (37 * 3 tokens, hidden size 23) does not.
_MODEL_CONFIGS = {
    "small": ModelConfig(max_seqlen_q=32, batch_size=2, hidden_size=32),
    "weird": ModelConfig(max_seqlen_q=37, batch_size=3, hidden_size=23),
}

# Parameter/input dtypes used to parametrize the module tests.
_DTYPES = [torch.float16, torch.bfloat16]
# ``(skip_wgrad, skip_dgrad)`` pairs.  The both-True combination is not listed;
# ``_run_common_module_sanity`` would skip it anyway.
_GRAD_CASES = [
    pytest.param(False, False, id="all_grads"),
    pytest.param(True, False, id="skip_wgrad"),
    pytest.param(False, True, id="skip_dgrad"),
]

# Probe quantization support once at import time; the reason strings are
# reused as skip messages by the ``skipif`` marks in this file.
_fp8_available, _reason_for_no_fp8 = is_fp8_available(return_reason=True)
_mxfp8_available, _reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)


# Module-wide mark: every test here is skipped when no NPU is available.
pytestmark = pytest.mark.skipif(not npu_available(), reason="NPU device is required")


def _npu_device() -> torch.device:
    """Return the ``torch.device`` object for the ``"npu"`` device type."""
    npu = torch.device("npu")
    return npu


def _randn_npu(shape, *, dtype: torch.dtype, requires_grad: bool = False) -> torch.Tensor:
    """Sample a random tensor on the CPU and return a copy placed on the NPU.

    The values are drawn with ``torch.randn`` on the host so that the NPU
    random kernels are not exercised.  ``requires_grad`` is applied after the
    transfer, i.e. to the device tensor that is returned.
    """
    host_sample = torch.randn(shape, dtype=dtype, device="cpu")
    device_sample = host_sample.to(device=_npu_device())
    return device_sample.requires_grad_(requires_grad)


def _seed() -> None:
    """Seed ``torch`` and ``torch.npu`` with the same fixed value (1234)."""
    fixed_seed = 1234
    torch.manual_seed(fixed_seed)
    torch.npu.manual_seed(fixed_seed)


def _sync() -> None:
    """Call ``torch.npu.synchronize()``.

    The tests invoke this after ``backward()`` and before inspecting
    gradients (presumably the NPU counterpart of ``torch.cuda.synchronize``).
    """
    npu_backend = torch.npu
    npu_backend.synchronize()


def _scaled_init_method_normal(
    sigma: float, num_layers: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return a normal initializer whose std-dev shrinks with model depth.

    The standard deviation handed to ``init_method_normal`` is
    ``sigma / sqrt(2 * num_layers)``.  This mirrors the intent of the NVIDIA
    sanity tests without depending on CUDA-only helpers.
    """
    depth_scaled_sigma = sigma / (2.0 * num_layers) ** 0.5
    return init_method_normal(depth_scaled_sigma)


def _module_recipe_cases() -> list[pytest.ParameterSet]:
    """Build the ``(model, fp8_recipe_cls)`` parameter sets for module tests.

    The two high-precision cases (recipe ``None``) carry no marks.  Every
    quantized recipe is paired with the "small" model and is always emitted
    with a ``skipif`` mark keyed on the matching availability probe.  A
    backend without FP8/MXFP8 support therefore reports those cases as
    skipped, with the probe's reason, instead of silently dropping them from
    collection.  This matches how ``test_sanity_linear_with_fp8_model_params``
    declares its MXFP8 case.

    Returns:
        Parameter sets suitable for
        ``pytest.mark.parametrize(("model", "fp8_recipe_cls"), ...)``.
    """
    cases = [
        pytest.param("small", None, id="small_no_fp8"),
        pytest.param("weird", None, id="weird_no_fp8"),
    ]
    # (test id, recipe class, availability flag, reason shown when skipped)
    quantized_recipes = (
        ("small_delayed", recipe.DelayedScaling, _fp8_available, _reason_for_no_fp8),
        ("small_current", recipe.Float8CurrentScaling, _fp8_available, _reason_for_no_fp8),
        ("small_mxfp8", recipe.MXFP8BlockScaling, _mxfp8_available, _reason_for_no_mxfp8),
    )
    cases.extend(
        pytest.param(
            "small",
            recipe_cls,
            id=case_id,
            # The mark, not a surrounding ``if``, decides whether the case runs.
            marks=pytest.mark.skipif(not available, reason=reason),
        )
        for case_id, recipe_cls, available, reason in quantized_recipes
    )
    return cases


def _make_recipe(recipe_cls: Optional[type[recipe.Recipe]]) -> Optional[recipe.Recipe]:
    """Instantiate ``recipe_cls`` with no arguments, or pass ``None`` through."""
    if recipe_cls is None:
        return None
    return recipe_cls()


def _is_fp8_supported(config: ModelConfig) -> bool:
    """Return whether the config's dimensions pass the multiple-of-16 check.

    Both the token count (``max_seqlen_q * batch_size``) and ``hidden_size``
    must be divisible by 16.
    """
    token_count = config.max_seqlen_q * config.batch_size
    if token_count % 16 != 0:
        return False
    return config.hidden_size % 16 == 0


def _skip_unsupported_fp8_config(config: ModelConfig, fp8_recipe: Optional[recipe.Recipe]) -> None:
    """Skip the current test when a recipe is requested for an unsupported shape.

    Does nothing when ``fp8_recipe`` is ``None``; otherwise calls
    ``pytest.skip`` if ``_is_fp8_supported(config)`` is false.
    """
    if fp8_recipe is not None and not _is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")


def _disable_wgrads(module: torch.nn.Module) -> None:
    """Set ``requires_grad = False`` on every parameter of ``module`` in place."""
    for weight in module.parameters():
        weight.requires_grad = False


@pytest.fixture(autouse=True)
def _reset_global_fp8_state():
    """Seed the RNGs before each test and reset FP8 global state afterwards.

    The fixture is ``autouse``, so pytest applies it to every test in this
    module without the test requesting it.
    """
    # Setup: fixed seeds so every test starts from the same RNG state.
    _seed()
    yield
    # Teardown: presumably prevents quantization state set up by one test from
    # leaking into the next -- confirm against FP8GlobalStateManager.reset.
    FP8GlobalStateManager.reset()


def _first_tensor(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Return the first element of a tuple output, or ``output`` itself otherwise."""
    if isinstance(output, tuple):
        return output[0]
    return output


def _assert_common_backward(
    module: torch.nn.Module,
    inp: torch.Tensor,
    *,
    skip_wgrad: bool,
    skip_dgrad: bool,
) -> None:
    """Assert that the gradients expected after ``backward()`` exist.

    Args:
        module: Module whose parameters are inspected.
        inp: Input tensor that was fed to the module.
        skip_wgrad: When true, parameter gradients are not checked.
        skip_dgrad: When true, the input gradient is not checked.

    Raises:
        AssertionError: If ``inp.grad`` is missing, or if any parameter with
            ``requires_grad`` set has no gradient (the message lists them).
    """
    if not skip_dgrad:
        assert inp.grad is not None, "Input gradient should be materialized."
    if skip_wgrad:
        return

    missing = []
    for name, param in module.named_parameters():
        # Frozen parameters are not expected to receive a gradient.
        if param.requires_grad and param.grad is None:
            missing.append(name)
    assert not missing, f"Parameter gradients should be materialized: {missing}"


def _run_common_module_sanity(
    module: torch.nn.Module,
    *,
    dtype: torch.dtype,
    config: ModelConfig,
    fp8_recipe: Optional[recipe.Recipe],
    skip_wgrad: bool,
    skip_dgrad: bool,
    microbatching: bool,
) -> None:
    """Run one forward/backward pass through ``module`` and check its gradients.

    Args:
        module: Module under test; called as ``module(inp)``.
        dtype: Dtype of the generated input.
        config: Supplies the ``(max_seqlen_q, batch_size, hidden_size)`` input shape.
        fp8_recipe: Recipe handed to ``autocast``; ``None`` disables autocast.
        skip_wgrad: Freeze all module parameters before the forward pass.
        skip_dgrad: Create the input without ``requires_grad``.
        microbatching: Call the module twice, with ``is_first_microbatch`` set
            to ``True`` and then ``False``, and use the second output.

    The test is skipped when both gradients are disabled or when the recipe is
    not supported for ``config``.
    """
    if skip_wgrad and skip_dgrad:
        pytest.skip("No gradient computation; skipping to avoid PyTorch RuntimeError.")
    _skip_unsupported_fp8_config(config, fp8_recipe)

    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    inp = _randn_npu(input_shape, dtype=dtype, requires_grad=not skip_dgrad)

    if skip_wgrad:
        _disable_wgrads(module)

    use_fp8 = fp8_recipe is not None
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        if not microbatching:
            raw_out = module(inp)
        else:
            # The first call's output is kept alive but otherwise unused.
            _warmup_out = module(inp, is_first_microbatch=True)
            raw_out = module(inp, is_first_microbatch=False)

    out = _first_tensor(raw_out)
    # Only the last (feature) dimension may differ from the input.
    assert out.shape[:-1] == inp.shape[:-1]
    out.sum().backward()
    _sync()

    _assert_common_backward(module, inp, skip_wgrad=skip_wgrad, skip_dgrad=skip_dgrad)


def _make_linear(config: ModelConfig, dtype: torch.dtype) -> Linear:
    """Build a square ``hidden_size`` x ``hidden_size`` ``Linear`` on the NPU.

    Weights use the depth-scaled normal initializer (sigma 0.023, scaled by
    ``config.num_layers``) and are created in ``dtype``.
    """
    width = config.hidden_size
    weight_init = _scaled_init_method_normal(0.023, config.num_layers)
    return Linear(
        width,
        width,
        init_method=weight_init,
        params_dtype=dtype,
        device="npu",
    )


@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize(("model", "fp8_recipe_cls"), _module_recipe_cases())
@pytest.mark.parametrize(("skip_wgrad", "skip_dgrad"), _GRAD_CASES)
@pytest.mark.parametrize("microbatching", [False, True])
def test_sanity_linear(
    dtype: torch.dtype,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
    model: str,
    skip_wgrad: bool,
    skip_dgrad: bool,
    microbatching: bool,
) -> None:
    """Run the shared forward/backward sanity check on a square ``Linear``.

    Adapted from NVIDIA test_sanity_linear.
    """
    model_config = _MODEL_CONFIGS[model]
    quant_recipe = _make_recipe(fp8_recipe_cls)
    linear = _make_linear(model_config, dtype)

    sanity_options = {
        "dtype": dtype,
        "config": model_config,
        "fp8_recipe": quant_recipe,
        "skip_wgrad": skip_wgrad,
        "skip_dgrad": skip_dgrad,
        "microbatching": microbatching,
    }
    _run_common_module_sanity(linear, **sanity_options)


@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize("parallel_mode", [None, "row"])
@pytest.mark.parametrize("return_bias", [False, True])
def test_sanity_linear_bias_path(
    dtype: torch.dtype,
    parallel_mode: Optional[str],
    return_bias: bool,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Verify where ``Linear`` applies its bias, as in NVIDIA TE.

    Depending on ``parallel_mode`` and ``return_bias`` the bias is either
    handed to ``general_gemm``, added outside the GEMM, or returned to the
    caller untouched.  The test spies on ``general_gemm`` to see which path
    the forward pass took and compares the output against a plain matmul.
    """

    from transformer_engine.pytorch.module import linear as linear_module

    config = _MODEL_CONFIGS["small"]
    hidden = config.hidden_size
    block = Linear(
        hidden,
        hidden,
        bias=True,
        return_bias=return_bias,
        parallel_mode=parallel_mode,
        params_dtype=dtype,
        device="npu",
    )
    bias_applied_by_module = not return_bias
    assert block.apply_bias == bias_applied_by_module
    assert block.gemm_bias_unfused_add == (parallel_mode == "row" and bias_applied_by_module)

    # Wrap general_gemm so each call records the ``bias`` keyword it received.
    recorded_biases = []
    real_general_gemm = linear_module.general_gemm

    def _spy_general_gemm(*args, **kwargs):
        recorded_biases.append(kwargs.get("bias"))
        return real_general_gemm(*args, **kwargs)

    monkeypatch.setattr(linear_module, "general_gemm", _spy_general_gemm)

    inp = _randn_npu((config.max_seqlen_q, hidden), dtype=dtype, requires_grad=True)
    result = block(inp)
    if return_bias:
        out, returned_bias = result
        assert returned_bias.shape == block.bias.shape
        assert returned_bias.dtype == dtype
    else:
        out = result

    # Reference: plain matmul, plus the bias only when the module applies it.
    expected = torch.matmul(inp, block.weight.t())
    if bias_applied_by_module:
        expected = expected + block.bias
    torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)

    assert recorded_biases, "Linear forward should route through general_gemm."
    expect_gemm_bias = block.apply_bias and not block.gemm_bias_unfused_add
    forward_gemm_got_bias = recorded_biases[0] is not None
    assert forward_gemm_got_bias == expect_gemm_bias

    out.sum().backward()
    _sync()
    # A returned (unapplied) bias takes no part in ``out``, so it gets no grad.
    assert (block.bias.grad is None) == return_bias


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", [32])
@pytest.mark.parametrize(
    "fp8_recipe_cls",
    [
        pytest.param(None, id="no_fp8"),
        pytest.param(
            recipe.MXFP8BlockScaling,
            marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
            id="mxfp8",
        ),
    ],
)
@pytest.mark.parametrize("fp8_model_params", [False, True])
@pytest.mark.parametrize("use_bias", [False, True])
def test_sanity_linear_with_fp8_model_params(
    dtype: torch.dtype,
    num_tokens: int,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
    fp8_model_params: bool,
    use_bias: bool,
) -> None:
    """Run ``Linear`` forward/backward with optional MXFP8 compute and parameters.

    Parameter quantization (``quantized_model_init``) is enabled only when a
    recipe is given and ``fp8_model_params`` is true; ``autocast`` is enabled
    whenever a recipe is given.
    """

    config = _MODEL_CONFIGS["small"]
    fp8_recipe = _make_recipe(fp8_recipe_cls)
    _skip_unsupported_fp8_config(config, fp8_recipe)

    in_features = config.hidden_size
    ffn_hidden_size = 4 * in_features
    use_fp8 = fp8_recipe is not None
    quantize_params = use_fp8 and fp8_model_params

    with quantized_model_init(enabled=quantize_params, recipe=fp8_recipe):
        te_linear = Linear(
            in_features,
            ffn_hidden_size,
            bias=use_bias,
            params_dtype=dtype,
            device="npu",
        )

    inp = _randn_npu((num_tokens, in_features), dtype=dtype, requires_grad=True)
    with autocast(enabled=use_fp8, recipe=fp8_recipe):
        out = te_linear(inp)
    out.sum().backward()
    _sync()

    expected_shape = (num_tokens, ffn_hidden_size)
    assert out.shape == expected_shape


@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize("num_tokens_per_nonempty_group", [0, 16])
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("use_bias", [False, True])
@pytest.mark.parametrize("fp8_recipe_cls", [pytest.param(None, id="no_fp8")])
def test_sanity_grouped_linear(
    dtype: torch.dtype,
    num_tokens_per_nonempty_group: int,
    empty_split: str,
    use_bias: bool,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
) -> None:
    """Run ``GroupedLinear`` forward/backward with one zero-token group.

    Four GEMM groups are used; the group selected by ``empty_split`` (first,
    last, or the middle index) receives no tokens.  With
    ``num_tokens_per_nonempty_group == 0`` every group is empty.

    Adapted from NVIDIA test_sanity_grouped_linear.
    """

    config = _MODEL_CONFIGS["small"]
    num_gemms = 4
    in_features = config.hidden_size
    ffn_hidden_size = 4 * in_features
    fp8_recipe = _make_recipe(fp8_recipe_cls)

    te_grouped_linear = GroupedLinear(
        num_gemms,
        in_features,
        ffn_hidden_size,
        bias=use_bias,
        params_dtype=dtype,
        device="npu",
    )

    # Any value other than "first"/"last" selects the middle group.
    empty_index = {"first": 0, "last": -1}.get(empty_split, num_gemms // 2)
    m_splits = [num_tokens_per_nonempty_group] * num_gemms
    m_splits[empty_index] = 0
    num_tokens = sum(m_splits)

    inp = _randn_npu((num_tokens, in_features), dtype=dtype, requires_grad=True)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = te_grouped_linear(inp, m_splits)
    out.sum().backward()
    _sync()

    expected_shape = (num_tokens, ffn_hidden_size)
    assert out.shape == expected_shape


def test_model_multiple_cast() -> None:
    """Check that ``Linear`` output dtype follows a later ``.half()`` cast.

    The same module is run first with its default parameters and a default
    ``torch.zeros`` input (float32 output expected), then again after both the
    module and the input are cast to half precision (float16 output expected).

    Adapted from NVIDIA test_model_multiple_cast.
    """

    full_precision_inp = torch.zeros((16, 16), device=_npu_device())
    module = Linear(16, 32, device="npu")

    full_precision_out = module(full_precision_inp)
    assert full_precision_out.dtype == torch.float32

    module.half()
    half_inp = full_precision_inp.half()
    half_out = module(half_inp)
    assert half_out.dtype == torch.float16


@pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8)
def test_quantized_model_init_high_precision_init_val() -> None:
    """Check the preserved high-precision init value of a quantized weight.

    A ``Linear`` built under
    ``quantized_model_init(preserve_high_precision_init_val=True)`` must expose
    a ``QuantizedTensor`` weight with the high-precision accessors, keep the
    preserved value on the CPU, and re-quantizing that value must reproduce the
    weight exactly.

    Adapted from NVIDIA test_quantized_model_init_high_precision_init_val.
    """

    with quantized_model_init(preserve_high_precision_init_val=True):
        module = Linear(128, 128, params_dtype=torch.bfloat16, device="npu")

    weight = module.weight
    assert isinstance(weight, QuantizedTensor)
    required_attrs = (
        "_high_precision_init_val",
        "get_high_precision_init_val",
        "clear_high_precision_init_val",
    )
    for attr_name in required_attrs:
        assert hasattr(weight, attr_name)

    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu"

    # Quantize the preserved value again into a fresh tensor of the same
    # shape, dtype and device as the weight.
    requantized = weight._get_quantizer().make_empty(
        shape=weight.shape,
        dtype=weight.dtype,
        device=weight.device,
    )
    high_precision_on_device = high_precision.to(weight.device)
    weight._get_quantizer().update_quantized(high_precision_on_device, requantized)

    # Zero tolerances: the two dequantized tensors must match exactly.
    torch.testing.assert_close(
        requantized.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )


# Normalization names used to parametrize ``test_sanity_layernorm_linear``.
_NORMALIZATIONS = ["LayerNorm", "RMSNorm"]
# Activation names; not referenced in the visible part of this file --
# presumably consumed by tests defined elsewhere (TODO confirm).
_ACTIVATIONS = ["gelu", "swiglu"]


def _make_layernorm_linear(
    config: ModelConfig,
    dtype: torch.dtype,
    *,
    zero_centered_gamma: bool = False,
    normalization: str = "LayerNorm",
) -> "LayerNormLinear":  # noqa: F821
    """Build a ``LayerNormLinear`` on the NPU mapping ``hidden_size`` to ``3 * hidden_size``.

    Weights use ``init_method_normal(0.023)`` and are created in ``dtype``;
    ``zero_centered_gamma`` and ``normalization`` are forwarded unchanged.
    ``LayerNormLinear`` is imported inside the function, not at module level.
    """
    from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear

    in_features = config.hidden_size
    weight_init = init_method_normal(0.023)
    return LayerNormLinear(
        in_features,
        3 * in_features,
        init_method=weight_init,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="npu",
    )


@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize(("model", "fp8_recipe_cls"), _module_recipe_cases())
@pytest.mark.parametrize(("skip_wgrad", "skip_dgrad"), _GRAD_CASES)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", _NORMALIZATIONS)
@pytest.mark.parametrize("microbatching", [False, True])
def test_sanity_layernorm_linear(
    dtype: torch.dtype,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
    model: str,
    skip_wgrad: bool,
    skip_dgrad: bool,
    zero_centered_gamma: bool,
    normalization: str,
    microbatching: bool,
) -> None:
    """Run the shared forward/backward sanity check on ``LayerNormLinear``.

    Adapted from NVIDIA test_sanity_layernorm_linear.  The parametrization
    covers:
    - FP16 and BF16 parameter/input dtypes
    - High-precision compute and the recipes from ``_module_recipe_cases``
    - LayerNorm and RMSNorm normalizations
    - Optional zero_centered_gamma
    - Microbatching (is_first_microbatch)
    - Selective gradient skip (wgrad / dgrad)
    """
    model_config = _MODEL_CONFIGS[model]
    quant_recipe = _make_recipe(fp8_recipe_cls)
    # Skip before the module is constructed on the device.
    _skip_unsupported_fp8_config(model_config, quant_recipe)

    layer = _make_layernorm_linear(
        model_config,
        dtype,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
    )

    sanity_options = {
        "dtype": dtype,
        "config": model_config,
        "fp8_recipe": quant_recipe,
        "skip_wgrad": skip_wgrad,
        "skip_dgrad": skip_dgrad,
        "microbatching": microbatching,
    }
    _run_common_module_sanity(layer, **sanity_options)