"""Module-level sanity coverage adapted from NVIDIA's PyTorch sanity tests.
These tests intentionally exercise user-facing PyTorch modules instead of
``transformer_engine_torch``/``tex`` single operators. The NPU port keeps a few
backend-specific layout choices, so this file validates the high-level training
path: module construction, autocast recipe wiring, forward/backward execution,
gradient materialization, zero-token handling, and quantized parameter init.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Optional
import pytest
import torch
from transformer_engine.common import recipe
from transformer_engine.pytorch import (
GroupedLinear,
Linear,
autocast,
quantized_model_init,
)
from transformer_engine.pytorch.quantization import (
FP8GlobalStateManager,
is_fp8_available,
is_mxfp8_available,
)
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
from transformer_engine.pytorch.utils import init_method_normal
from utils import npu_available
@dataclass(frozen=True)
class ModelConfig:
    """Immutable geometry for one sanity-test model configuration.

    The sanity runner builds inputs of shape
    ``(max_seqlen_q, batch_size, hidden_size)``; ``num_layers`` feeds the
    depth-scaled weight initializer.
    """

    # Sequence length; first dimension of the generated module input.
    max_seqlen_q: int
    # Batch size; second dimension of the generated module input.
    batch_size: int
    # Input feature width; each test derives its output width from it.
    hidden_size: int
    # Depth used to shrink the init std in _scaled_init_method_normal.
    num_layers: int = 2
# Named model geometries shared by the parametrized tests below.
_MODEL_CONFIGS = dict(
    small=ModelConfig(max_seqlen_q=32, batch_size=2, hidden_size=32),
    # Deliberately not multiples of 16, so _is_fp8_supported rejects it.
    weird=ModelConfig(max_seqlen_q=37, batch_size=3, hidden_size=23),
)

# Parameter / input dtypes exercised by the module tests.
_DTYPES = [torch.float16, torch.bfloat16]

# (skip_wgrad, skip_dgrad) combinations; skipping both is rejected by the runner.
_GRAD_CASES = [
    pytest.param(skip_w, skip_d, id=case_id)
    for case_id, skip_w, skip_d in (
        ("all_grads", False, False),
        ("skip_wgrad", True, False),
        ("skip_dgrad", False, True),
    )
]

# Capability probes evaluated once at import time; each yields (available, reason).
_fp8_available, _reason_for_no_fp8 = is_fp8_available(return_reason=True)
_mxfp8_available, _reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)

# Skip every test in this module when no NPU device is present.
pytestmark = pytest.mark.skipif(not npu_available(), reason="NPU device is required")
def _npu_device() -> torch.device:
    """Return an NPU ``torch.device`` with no explicit device index."""
    npu = torch.device("npu")
    return npu
def _randn_npu(shape, *, dtype: torch.dtype, requires_grad: bool = False) -> torch.Tensor:
    """Draw a standard-normal tensor on the host and copy it to the NPU.

    Sampling happens on the CPU so test inputs never exercise NPU random
    kernels. The device copy is then optionally marked as requiring grad.
    """
    host_sample = torch.randn(shape, dtype=dtype, device="cpu")
    device_copy = host_sample.to(device=_npu_device())
    return device_copy.requires_grad_(requires_grad)
def _seed() -> None:
    """Seed the host generator, then the NPU generator, with a fixed value."""
    seed_value = 1234
    torch.manual_seed(seed_value)  # host (CPU) generator
    torch.npu.manual_seed(seed_value)  # NPU-side generator
def _sync() -> None:
    """Block the host until outstanding NPU work finishes (``torch.npu.synchronize``)."""
    torch.npu.synchronize()
def _scaled_init_method_normal(
    sigma: float, num_layers: int
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return a normal initializer with std ``sigma / sqrt(2 * num_layers)``.

    Matches the intent of NVIDIA's sanity-test helper while relying only on
    ``init_method_normal`` rather than CUDA-only helpers.
    """
    depth_scale = (2.0 * num_layers) ** 0.5
    return init_method_normal(sigma / depth_scale)
def _module_recipe_cases() -> list[pytest.ParameterSet]:
    """Build the ``(model, fp8_recipe_cls)`` cases shared by the module tests.

    High-precision cases (``fp8_recipe_cls=None``) run on every model config.
    FP8 recipe cases are paired only with the FP8-compatible ``"small"`` config.
    They are always generated and carry a ``skipif`` mark. On hardware without
    FP8/MXFP8 support they are therefore reported as skipped, with the backend's
    reason, instead of silently disappearing from the test matrix. (Previously
    they were only appended when already available, which made the marks dead
    code.)
    """
    fp8_skip = pytest.mark.skipif(not _fp8_available, reason=_reason_for_no_fp8)
    mxfp8_skip = pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8)
    return [
        pytest.param("small", None, id="small_no_fp8"),
        pytest.param("weird", None, id="weird_no_fp8"),
        pytest.param("small", recipe.DelayedScaling, id="small_delayed", marks=fp8_skip),
        pytest.param("small", recipe.Float8CurrentScaling, id="small_current", marks=fp8_skip),
        pytest.param("small", recipe.MXFP8BlockScaling, id="small_mxfp8", marks=mxfp8_skip),
    ]
def _make_recipe(recipe_cls: Optional[type[recipe.Recipe]]) -> Optional[recipe.Recipe]:
    """Instantiate ``recipe_cls`` with default arguments; pass ``None`` through."""
    if recipe_cls is None:
        return None
    return recipe_cls()
def _is_fp8_supported(config: ModelConfig) -> bool:
    """Report whether ``config`` has FP8-friendly sizes.

    Both the token count (``max_seqlen_q * batch_size``) and ``hidden_size``
    must be multiples of 16.
    """
    num_tokens = config.max_seqlen_q * config.batch_size
    return num_tokens % 16 == 0 and config.hidden_size % 16 == 0
def _skip_unsupported_fp8_config(config: ModelConfig, fp8_recipe: Optional[recipe.Recipe]) -> None:
    """Skip the running test when an FP8 recipe meets an FP8-incompatible config.

    High-precision runs (``fp8_recipe is None``) are never skipped here.
    """
    uses_fp8 = fp8_recipe is not None
    if uses_fp8 and not _is_fp8_supported(config):
        pytest.skip("Model config does not support FP8")
def _disable_wgrads(module: torch.nn.Module) -> None:
    """Freeze every parameter of ``module`` so backward produces no weight grads."""
    for weight in module.parameters():
        weight.requires_grad = False
@pytest.fixture(autouse=True)
def _reset_global_fp8_state():
    """Autouse fixture: seed RNGs before each test and reset FP8 globals after it.

    The code after ``yield`` runs as fixture teardown, so
    ``FP8GlobalStateManager.reset()`` runs once the test body finishes.
    NOTE(review): this presumably clears state accumulated under ``autocast``;
    confirm this against FP8GlobalStateManager.
    """
    _seed()
    yield
    FP8GlobalStateManager.reset()
def _first_tensor(output: torch.Tensor | tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Return ``output`` unchanged, or its first element if a tuple was returned."""
    if isinstance(output, tuple):
        return output[0]
    return output
def _assert_common_backward(
    module: torch.nn.Module,
    inp: torch.Tensor,
    *,
    skip_wgrad: bool,
    skip_dgrad: bool,
) -> None:
    """Verify that every gradient the test asked for was actually produced.

    Args:
        module: Module that was run forward/backward.
        inp: Input tensor fed to ``module``.
        skip_wgrad: When true, parameter gradients are not checked.
        skip_dgrad: When true, the input gradient is not checked.

    Raises:
        AssertionError: if the input grad is missing, or if any parameter that
            requires grad has ``grad is None``; the message lists those names.
    """
    if not skip_dgrad:
        assert inp.grad is not None, "Input gradient should be materialized."
    if skip_wgrad:
        return
    missing = []
    for name, param in module.named_parameters():
        if param.requires_grad and param.grad is None:
            missing.append(name)
    assert not missing, f"Parameter gradients should be materialized: {missing}"
def _run_common_module_sanity(
    module: torch.nn.Module,
    *,
    dtype: torch.dtype,
    config: ModelConfig,
    fp8_recipe: Optional[recipe.Recipe],
    skip_wgrad: bool,
    skip_dgrad: bool,
    microbatching: bool,
) -> None:
    """Drive one forward/backward pass through ``module`` and check its gradients.

    The input has shape ``(max_seqlen_q, batch_size, hidden_size)`` and is drawn
    on the host. The forward pass runs under ``autocast``, enabled only when
    ``fp8_recipe`` is given. The summed output is then backpropagated outside
    ``autocast``, and the requested input/parameter gradients are verified.

    Args:
        module: Module under test; its parameters are frozen when ``skip_wgrad``.
        dtype: Dtype of the generated input.
        config: Geometry used to size the input (and to gate FP8 support).
        fp8_recipe: Quantization recipe, or ``None`` for high precision.
        skip_wgrad: Freeze parameters so no weight gradients are produced.
        skip_dgrad: Create the input without ``requires_grad``.
        microbatching: Call the module with ``is_first_microbatch=True`` and
            then ``False``; only the second output is checked.
    """
    if skip_dgrad and skip_wgrad:
        # Nothing would require grad, so backward() could not run at all.
        pytest.skip("No gradient computation; skipping to avoid PyTorch RuntimeError.")
    _skip_unsupported_fp8_config(config, fp8_recipe)

    input_shape = (config.max_seqlen_q, config.batch_size, config.hidden_size)
    inp = _randn_npu(input_shape, dtype=dtype, requires_grad=not skip_dgrad)
    if skip_wgrad:
        _disable_wgrads(module)

    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        if not microbatching:
            raw_out = module(inp)
        else:
            _ = module(inp, is_first_microbatch=True)
            raw_out = module(inp, is_first_microbatch=False)

    out = _first_tensor(raw_out)
    # Modules may change the feature width, but the leading dims must survive.
    assert out.shape[:-1] == inp.shape[:-1]
    out.sum().backward()
    _sync()

    _assert_common_backward(module, inp, skip_wgrad=skip_wgrad, skip_dgrad=skip_dgrad)
def _make_linear(config: ModelConfig, dtype: torch.dtype) -> Linear:
    """Build a square ``hidden_size -> hidden_size`` TE ``Linear`` on the NPU.

    Weights use the depth-scaled normal initializer with ``sigma=0.023``.
    """
    width = config.hidden_size
    weight_init = _scaled_init_method_normal(0.023, config.num_layers)
    return Linear(
        width,
        width,
        init_method=weight_init,
        params_dtype=dtype,
        device="npu",
    )
@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize(("model", "fp8_recipe_cls"), _module_recipe_cases())
@pytest.mark.parametrize(("skip_wgrad", "skip_dgrad"), _GRAD_CASES)
@pytest.mark.parametrize("microbatching", [False, True])
def test_sanity_linear(
    dtype: torch.dtype,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
    model: str,
    skip_wgrad: bool,
    skip_dgrad: bool,
    microbatching: bool,
) -> None:
    """Adapted from NVIDIA test_sanity_linear.

    Builds a square Linear for the selected model config and runs the shared
    forward/backward sanity routine under the selected recipe.
    """
    config = _MODEL_CONFIGS[model]
    fp8_recipe = _make_recipe(fp8_recipe_cls)
    linear = _make_linear(config, dtype)
    _run_common_module_sanity(
        linear,
        dtype=dtype,
        config=config,
        fp8_recipe=fp8_recipe,
        skip_wgrad=skip_wgrad,
        skip_dgrad=skip_dgrad,
        microbatching=microbatching,
    )
@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize("parallel_mode", [None, "row"])
@pytest.mark.parametrize("return_bias", [False, True])
def test_sanity_linear_bias_path(
    dtype: torch.dtype,
    parallel_mode: Optional[str],
    return_bias: bool,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that Linear routes bias through GEMM or an unfused add like NVIDIA TE.

    Covered paths:
    - ``return_bias=True``: the bias is returned, not applied, and gets no grad.
    - ``parallel_mode="row"`` without ``return_bias``: ``gemm_bias_unfused_add``
      is set and the forward GEMM must receive no bias.
    - otherwise: the forward GEMM receives the bias.
    ``general_gemm`` is wrapped to record the ``bias`` keyword of each call.
    """
    from transformer_engine.pytorch.module import linear as linear_module

    config = _MODEL_CONFIGS["small"]
    width = config.hidden_size
    block = Linear(
        width,
        width,
        bias=True,
        return_bias=return_bias,
        parallel_mode=parallel_mode,
        params_dtype=dtype,
        device="npu",
    )
    assert block.apply_bias == (not return_bias)
    assert block.gemm_bias_unfused_add == (parallel_mode == "row" and not return_bias)

    # Spy on general_gemm to see which bias (if any) the forward GEMM received.
    gemm_bias_log = []
    real_general_gemm = linear_module.general_gemm

    def _spy_general_gemm(*args, **kwargs):
        gemm_bias_log.append(kwargs.get("bias"))
        return real_general_gemm(*args, **kwargs)

    monkeypatch.setattr(linear_module, "general_gemm", _spy_general_gemm)

    inp = _randn_npu((config.max_seqlen_q, width), dtype=dtype, requires_grad=True)
    result = block(inp)
    if return_bias:
        out, returned_bias = result
        assert returned_bias.shape == block.bias.shape
        assert returned_bias.dtype == dtype
    else:
        out = result

    # Reference y = x @ W^T, plus the bias only when the module applies it itself.
    reference = torch.matmul(inp, block.weight.t())
    if not return_bias:
        reference = reference + block.bias
    torch.testing.assert_close(out, reference, rtol=1e-2, atol=1e-2)

    assert gemm_bias_log, "Linear forward should route through general_gemm."
    gemm_received_bias = gemm_bias_log[0] is not None
    assert gemm_received_bias == (block.apply_bias and not block.gemm_bias_unfused_add)

    out.sum().backward()
    _sync()
    if return_bias:
        assert block.bias.grad is None
    else:
        assert block.bias.grad is not None
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", [32])
@pytest.mark.parametrize(
    "fp8_recipe_cls",
    [
        pytest.param(None, id="no_fp8"),
        pytest.param(
            recipe.MXFP8BlockScaling,
            marks=pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8),
            id="mxfp8",
        ),
    ],
)
@pytest.mark.parametrize("fp8_model_params", [False, True])
@pytest.mark.parametrize("use_bias", [False, True])
def test_sanity_linear_with_fp8_model_params(
    dtype: torch.dtype,
    num_tokens: int,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
    fp8_model_params: bool,
    use_bias: bool,
) -> None:
    """Exercise Linear with optional MXFP8 activation and model-parameter quantization.

    Builds a ``hidden_size -> 4 * hidden_size`` Linear. It is built under an
    enabled ``quantized_model_init`` only when a recipe is active and
    ``fp8_model_params`` is set. Forward runs under ``autocast`` and backward
    runs outside it. The test then checks:
    - that the weight really is a ``QuantizedTensor`` when quantized init was
      enabled,
    - the output shape, and
    - that the input gradient was produced.
    """
    config = _MODEL_CONFIGS["small"]
    fp8_recipe = _make_recipe(fp8_recipe_cls)
    _skip_unsupported_fp8_config(config, fp8_recipe)
    ffn_hidden_size = 4 * config.hidden_size
    quantize_params = fp8_recipe is not None and fp8_model_params
    with quantized_model_init(
        enabled=quantize_params,
        recipe=fp8_recipe,
    ):
        te_linear = Linear(
            config.hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            params_dtype=dtype,
            device="npu",
        )
    if quantize_params:
        # Without this check the fp8_model_params cases would pass even if
        # quantized_model_init silently left the weight in high precision.
        assert isinstance(
            te_linear.weight, QuantizedTensor
        ), "quantized_model_init should store the weight as a QuantizedTensor."
    inp = _randn_npu(
        (num_tokens, config.hidden_size),
        dtype=dtype,
        requires_grad=True,
    )
    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = te_linear(inp)
    loss = out.sum()
    loss.backward()
    _sync()
    assert out.shape == (num_tokens, ffn_hidden_size)
    assert inp.grad is not None, "Input gradient should be materialized."
@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize("num_tokens_per_nonempty_group", [0, 16])
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("use_bias", [False, True])
@pytest.mark.parametrize("fp8_recipe_cls", [pytest.param(None, id="no_fp8")])
def test_sanity_grouped_linear(
    dtype: torch.dtype,
    num_tokens_per_nonempty_group: int,
    empty_split: str,
    use_bias: bool,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
) -> None:
    """Adapted from NVIDIA test_sanity_grouped_linear.

    Runs a 4-GEMM GroupedLinear in which exactly one split (first, last or
    middle) has zero tokens. With ``num_tokens_per_nonempty_group=0`` every
    split is empty, so the zero-token path is exercised end to end.
    """
    config = _MODEL_CONFIGS["small"]
    num_gemms = 4
    ffn_hidden_size = 4 * config.hidden_size
    fp8_recipe = _make_recipe(fp8_recipe_cls)
    te_grouped_linear = GroupedLinear(
        num_gemms,
        config.hidden_size,
        ffn_hidden_size,
        bias=use_bias,
        params_dtype=dtype,
        device="npu",
    )

    # Any label other than "first"/"last" selects the middle split.
    empty_index = {"first": 0, "last": num_gemms - 1}.get(empty_split, num_gemms // 2)
    m_splits = [
        0 if gemm_idx == empty_index else num_tokens_per_nonempty_group
        for gemm_idx in range(num_gemms)
    ]
    num_tokens = sum(m_splits)

    inp = _randn_npu((num_tokens, config.hidden_size), dtype=dtype, requires_grad=True)
    with autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out = te_grouped_linear(inp, m_splits)
    out.sum().backward()
    _sync()
    assert out.shape == (num_tokens, ffn_hidden_size)
def test_model_multiple_cast() -> None:
    """Adapted from NVIDIA test_model_multiple_cast.

    Checks the output dtype before (float32) and after ``module.half()``
    (float16, with a half-precision input).
    """
    inp = torch.zeros((16, 16), device=_npu_device())
    module = Linear(16, 32, device="npu")
    assert module(inp).dtype == torch.float32

    module.half()
    assert module(inp.half()).dtype == torch.float16
@pytest.mark.skipif(not _mxfp8_available, reason=_reason_for_no_mxfp8)
def test_quantized_model_init_high_precision_init_val() -> None:
    """Adapted from NVIDIA test_quantized_model_init_high_precision_init_val.

    With ``preserve_high_precision_init_val=True`` the quantized weight keeps a
    high-precision copy of its initial value. The test checks that:
    - the weight is a ``QuantizedTensor`` exposing the get/clear helpers,
    - the preserved value lives on the CPU,
    - re-quantizing it reproduces the stored weight exactly, and
    - ``clear_high_precision_init_val()`` actually drops the cached value.
    """
    with quantized_model_init(preserve_high_precision_init_val=True):
        module = Linear(128, 128, params_dtype=torch.bfloat16, device="npu")
    weight = module.weight
    assert isinstance(weight, QuantizedTensor)
    assert hasattr(weight, "_high_precision_init_val")
    assert hasattr(weight, "get_high_precision_init_val")
    assert hasattr(weight, "clear_high_precision_init_val")
    high_precision = weight.get_high_precision_init_val()
    assert high_precision.device.type == "cpu"
    new_weight = weight._get_quantizer().make_empty(
        shape=weight.shape,
        dtype=weight.dtype,
        device=weight.device,
    )
    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)
    # rtol=atol=0: quantizing the preserved value must reproduce the weight exactly.
    torch.testing.assert_close(
        new_weight.dequantize(dtype=weight.dtype),
        weight.dequantize(dtype=weight.dtype),
        rtol=0,
        atol=0,
    )
    # The clear helper's existence is asserted above; also verify it takes effect.
    weight.clear_high_precision_init_val()
    assert (
        weight.get_high_precision_init_val() is None
    ), "clear_high_precision_init_val() should drop the cached init value."
# Normalization variants passed to LayerNormLinear's ``normalization`` argument.
_NORMALIZATIONS = ["LayerNorm", "RMSNorm"]
# NOTE(review): not referenced by the tests in this section; presumably meant
# for MLP-style module tests elsewhere in the file. Confirm before removing.
_ACTIVATIONS = ["gelu", "swiglu"]
def _make_layernorm_linear(
    config: ModelConfig,
    dtype: torch.dtype,
    *,
    zero_centered_gamma: bool = False,
    normalization: str = "LayerNorm",
) -> "LayerNormLinear":
    """Build a ``hidden_size -> 3 * hidden_size`` LayerNormLinear on the NPU.

    Weights use an unscaled normal initializer with ``sigma=0.023``. The
    normalization flavour and ``zero_centered_gamma`` are forwarded unchanged.
    """
    from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear

    in_features = config.hidden_size
    return LayerNormLinear(
        in_features,
        in_features * 3,
        init_method=init_method_normal(0.023),
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
        params_dtype=dtype,
        device="npu",
    )
@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize(("model", "fp8_recipe_cls"), _module_recipe_cases())
@pytest.mark.parametrize(("skip_wgrad", "skip_dgrad"), _GRAD_CASES)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", _NORMALIZATIONS)
@pytest.mark.parametrize("microbatching", [False, True])
def test_sanity_layernorm_linear(
    dtype: torch.dtype,
    fp8_recipe_cls: Optional[type[recipe.Recipe]],
    model: str,
    skip_wgrad: bool,
    skip_dgrad: bool,
    zero_centered_gamma: bool,
    normalization: str,
    microbatching: bool,
) -> None:
    """Adapted from NVIDIA test_sanity_layernorm_linear.

    Exercises LayerNormLinear forward/backward with:
    - LayerNorm and RMSNorm normalizations
    - Optional zero_centered_gamma
    - FP16 and BF16 parameter/input dtypes
    - High-precision compute and, where the backend supports them, the
      DelayedScaling, Float8CurrentScaling and MXFP8BlockScaling recipes
      (FP8 recipes are only paired with the FP8-compatible "small" config)
    - Microbatching (is_first_microbatch)
    - Selective gradient skip (wgrad / dgrad)
    """
    config = _MODEL_CONFIGS[model]
    fp8_recipe = _make_recipe(fp8_recipe_cls)
    # Skip before constructing the module; the shared runner re-checks this too.
    _skip_unsupported_fp8_config(config, fp8_recipe)
    block = _make_layernorm_linear(
        config,
        dtype,
        zero_centered_gamma=zero_centered_gamma,
        normalization=normalization,
    )
    _run_common_module_sanity(
        block,
        dtype=dtype,
        config=config,
        fp8_recipe=fp8_recipe,
        skip_wgrad=skip_wgrad,
        skip_dgrad=skip_dgrad,
        microbatching=microbatching,
    )